Merge remote-tracking branch 'aosp/upstream-master' into rebase_tf

Bug: 113615477
Test: mm
Test: NeuralNetworkTest_static
Change-Id: I6ff1cdf69034c37287e214fbc972fdf0de569c53
diff --git a/RELEASE.md b/RELEASE.md
index 763ef3b..bdc2379 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -1,3 +1,9 @@
+# Release 1.10.1
+## Bug Fixes and Other Changes
+
+* `tf.keras`:
+  * Fixing keras on Cloud TPUs. No new binaries will be built for Windows.
+
 # Release 1.10.0
 
 ## Major Features And Improvements
diff --git a/configure.py b/configure.py
index 361bd47..52a5137 100644
--- a/configure.py
+++ b/configure.py
@@ -852,7 +852,7 @@
 
     # Reset and retry
     print('Invalid path to CUDA %s toolkit. %s cannot be found' %
-          (tf_cuda_version, cuda_toolkit_path_full))
+          (tf_cuda_version, cuda_toolkit_paths_full))
     environ_cp['TF_CUDA_VERSION'] = ''
     environ_cp['CUDA_TOOLKIT_PATH'] = ''
 
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index b5e0a4e..386e009 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -12,6 +12,7 @@
     # The leakr files are used by //third_party/cloud_tpu.
     "leakr_badwords.dic",
     "leakr_badfiles.dic",
+    "leakr_file_type_recipe.ftrcp",
 ])
 
 load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
@@ -23,6 +24,11 @@
     "//tensorflow/python/tools/api/generator:api_gen.bzl",
     "gen_api_init_files",  # @unused
 )
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files")
+load(
+    "//tensorflow/python/tools/api/generator:api_init_files.bzl",
+    "TENSORFLOW_API_INIT_FILES",  # @unused
+)
 load(
     "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
     "TENSORFLOW_API_INIT_FILES_V1",  # @unused
@@ -32,6 +38,11 @@
     "if_ngraph",
 )
 
+# @unused
+TENSORFLOW_API_INIT_FILES_V2 = (
+    TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
 # Config setting used when building for products
 # which requires restricted licenses to be avoided.
 config_setting(
@@ -427,12 +438,20 @@
     visibility = ["//visibility:public"],
 )
 
+# This flag specifies whether TensorFlow 2.0 API should be built instead
+# of 1.* API. Note that TensorFlow 2.0 API is currently under development.
+config_setting(
+    name = "api_version_2",
+    define_values = {"tf_api_version": "2"},
+)
+
 package_group(
     name = "internal",
     packages = [
         "-//third_party/tensorflow/python/estimator",
         "//learning/meta_rank/...",
         "//tensorflow/...",
+        "//tensorflow_estimator/...",
         "//tensorflow_fold/llgtm/...",
         "//third_party/py/tensor2tensor/...",
     ],
@@ -590,13 +609,39 @@
 )
 
 gen_api_init_files(
-    name = "tensorflow_python_api_gen",
+    name = "tf_python_api_gen_v1",
     srcs = ["api_template.__init__.py"],
     api_version = 1,
+    output_dir = "_api/v1/",
     output_files = TENSORFLOW_API_INIT_FILES_V1,
+    output_package = "tensorflow._api.v1",
     root_init_template = "api_template.__init__.py",
 )
 
+gen_api_init_files(
+    name = "tf_python_api_gen_v2",
+    srcs = ["api_template.__init__.py"],
+    api_version = 2,
+    compat_api_versions = [1],
+    output_dir = "_api/v2/",
+    output_files = TENSORFLOW_API_INIT_FILES_V2,
+    output_package = "tensorflow._api.v2",
+    root_init_template = "api_template.__init__.py",
+)
+
+genrule(
+    name = "root_init_gen",
+    srcs = select({
+        "api_version_2": [":tf_python_api_gen_v2"],
+        "//conditions:default": [":tf_python_api_gen_v1"],
+    }),
+    outs = ["__init__.py"],
+    cmd = select({
+        "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+        "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+    }),
+)
+
 py_library(
     name = "tensorflow_py",
     srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
@@ -611,7 +656,10 @@
 
 py_library(
     name = "tensorflow_py_no_contrib",
-    srcs = [":tensorflow_python_api_gen"],
+    srcs = select({
+        "api_version_2": [":tf_python_api_gen_v2"],
+        "//conditions:default": [":tf_python_api_gen_v1"],
+    }) + [":root_init_gen"],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = ["//tensorflow/python:no_contrib"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 779f65d..53a72b8 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -18,11 +18,12 @@
 from __future__ import division
 from __future__ import print_function
 
+import os as _os
+
 # pylint: disable=g-bad-import-order
 from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
 
 try:
-  import os  # pylint: disable=g-import-not-at-top
   # Add `estimator` attribute to allow access to estimator APIs via
   # "tf.estimator..."
   from tensorflow.python.estimator.api import estimator  # pylint: disable=g-import-not-at-top
@@ -30,9 +31,8 @@
   # Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
   # style imports.
   from tensorflow.python.estimator import api as estimator_api  # pylint: disable=g-import-not-at-top
-  __path__ += [os.path.dirname(estimator_api.__file__)]
+  __path__ += [_os.path.dirname(estimator_api.__file__)]
   del estimator_api
-  del os
 except (ImportError, AttributeError):
   print('tf.estimator package not installed.')
 
@@ -45,6 +45,12 @@
 from tensorflow.python.platform import flags  # pylint: disable=g-import-not-at-top
 app.flags = flags  # pylint: disable=undefined-variable
 
+# Make sure directory containing top level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__))  # pylint: disable=undefined-variable
+if _tf_api_dir not in __path__:
+  __path__.append(_tf_api_dir)
+
 del absolute_import
 del division
 del print_function
@@ -54,6 +60,12 @@
 # must come from this module. So python adds these symbols for the
 # resolution to succeed.
 # pylint: disable=undefined-variable
-del python
-del core
+try:
+  del python
+  del core
+except NameError:
+  # Don't fail if these modules are not available.
+  # For e.g. we are using this file for compat.v1 module as well and
+  # 'python', 'core' directories are not under compat/v1.
+  pass
 # pylint: enable=undefined-variable
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 109b3b3..43c279b 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -204,6 +204,7 @@
         "//tensorflow:darwin": ["-headerpad_max_install_names"],
         "//conditions:default": [],
     }),
+    tags = ["noasan"],
     # We must ensure that the dependencies can be dynamically linked since
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 173bbea..79811ce 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -39,6 +39,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 69b3ffe..c195c9e 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -17,6 +17,7 @@
 
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -79,6 +80,18 @@
   auto* gpu_options = config.mutable_gpu_options();
   gpu_options->set_allow_growth(gpu_memory_allow_growth);
 
+  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
+  // threadpool, so that we avoid the possibility of running the runner_ in the
+  // threadpool of GPU event mgr, as that can trigger more callbacks to be
+  // scheduled on that same threadpool, causing a deadlock in cases where the
+  // caller of event_mgr->ThenExecute() blocks on the completion of the callback
+  // (as in the case of ConstOp kernel creation on GPU, which involves copying a
+  // CPU tensor to GPU).
+  // Setting a larger thread pool does not help with the Swift caller, as we use
+  // a different TFE context for each thread of execution (for running graph
+  // functions, and their send/recvs coroutines).
+  config.set_inter_op_parallelism_threads(1);
+
   TF_Buffer* ret = TF_NewBuffer();
   TF_CHECK_OK(MessageToBuffer(config, ret));
   return ret;
@@ -8494,3 +8507,201 @@
                 /*run_metadata*/ nullptr, status);
   VLOG(1) << "Enqueuing is done.";
 }
+
+TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
+                                          TF_Status* status) {
+  auto* opts = TFE_NewContextOptions();
+
+  // Reduce GPU memory allocation, and set appropriate config options for TFE
+  // context.
+  auto* config =
+      TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
+  TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
+  if (!status->status.ok()) {
+    CHECK(!config);
+    TFE_DeleteContextOptions(opts);
+    return nullptr;
+  }
+
+  auto* ctx = TFE_NewContextFromSession(opts, session, status);
+  TF_DeleteBuffer(config);
+  TFE_DeleteContextOptions(opts);
+  return ctx;
+}
+
+// TODO: retrieve the device string via TFE_ContextListDevices()
+static const char DEFAULT_CPU_DEVICE[] =
+    "/job:localhost/replica:0/task:0/device:CPU:0";
+
+static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
+                                        int tensor_id, TF_Status* status) {
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
+      TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
+  TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
+  if (!status->status.ok()) return nullptr;
+  // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
+  TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
+  TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
+  auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
+  TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
+                      shared_name.size());
+  TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
+
+  // TODO: consider making this an unknown shape.
+  const int64_t* dims_ptr = nullptr;
+  int num_dims = 0;
+  TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
+                         /*num_values*/ 0, status);
+  if (!status->status.ok()) return nullptr;
+
+  int num_retvals = 1;
+  TFE_TensorHandle* queue = nullptr;
+  TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
+  if (!status->status.ok()) return nullptr;
+  CHECK_EQ(num_retvals, 1);
+
+  return queue;
+}
+
+static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
+                             TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
+                             TF_Status* status) {
+  TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+  if (!status->status.ok()) return;
+  TFE_OpAddInput(op, queue, status);
+  if (!status->status.ok()) return;
+  TFE_OpAddInput(op, tensor, status);
+  if (!status->status.ok()) return;
+  TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
+  TFE_OpSetAttrInt(op, "timeout_ms", -1);
+
+  int num_retvals = 0;
+  TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
+  if (!status->status.ok()) return;
+  CHECK_EQ(num_retvals, 0);
+}
+
+static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
+                                          TF_DataType inputType,
+                                          TFE_TensorHandle* queue,
+                                          TF_Status* status) {
+  TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+  if (!status->status.ok()) return nullptr;
+
+  TFE_OpAddInput(op, queue, status);
+  if (!status->status.ok()) return nullptr;
+  TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
+  TFE_OpSetAttrInt(op, "timeout_ms", -1);
+  TFE_TensorHandle* ret;
+  int num_retvals = 1;
+  TFE_Execute(op, &ret, &num_retvals, status);
+  if (!status->status.ok()) return nullptr;
+  CHECK_EQ(num_retvals, 1);
+  return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
+                                         TF_DataType inputType,
+                                         TF_Status* status) {
+  assert(session);
+  VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
+
+  auto ctx = TFE_CreateContextFromSession(session, status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+      ctx, TFE_DeleteContext);
+
+  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+  return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+                                                TF_DataType inputType,
+                                                TF_Status* status) {
+  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+
+  return ret;
+}
+
+void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+                            TFE_TensorHandle* tensor, TF_Status* status) {
+  assert(session);
+  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+  auto ctx = TFE_CreateContextFromSession(session, status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+      ctx, TFE_DeleteContext);
+
+  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+                                   TFE_TensorHandle* tensor,
+                                   TF_Status* status) {
+  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
+                              TFE_TensorHandle* tensor, TF_Status* status) {
+  VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
+
+  auto ctx = TFE_CreateContextFromSession(session, status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+      ctx, TFE_DeleteContext);
+
+  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+  if (!status->status.ok()) return;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
+}
+
+TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
+                                           TF_Status* status) {
+  VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
+
+  auto ctx = TFE_CreateContextFromSession(session, status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+      ctx, TFE_DeleteContext);
+
+  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+  if (!status->status.ok()) return nullptr;
+  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+      queue_deleter(queue, TFE_DeleteTensorHandle);
+
+  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 09d482d..522c91f 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -132,9 +132,48 @@
                                                  TF_Tensor* tensor,
                                                  TF_Status* status);
 
+// TODO: remove this API in favor of the next one.
 TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
     const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
 
+// Creates from `session` a new eager context to run a graph function or
+// sends/recvs, so that these concurrent TFE executions can share (via
+// `session` and its associated device mgr) the same set of fifo queue resource
+// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
+// graph function execution can access the same fifo queue resource handles
+// (associated with devices managed by the device manager, which can be obtained
+// from `session`).
+//
+// TODO: Remove this function once we migrate away from using session.
+TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
+    TF_Session* session, TF_Status* status);
+
+// TODO: Retire this API in favor of the next one.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
+    TF_Session* session, int tensor_id, TF_DataType inputType,
+    TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
+    TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
+                                                  int tensor_id,
+                                                  TFE_TensorHandle* tensor,
+                                                  TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
+    TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
+    TF_Status* status);
+
+// TODO: consider folding the 2 APIs below into the ones above.
+TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
+                                                    int tensor_id,
+                                                    TFE_TensorHandle* tensor,
+                                                    TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
+    TF_Session* session, int tensor_id, TF_Status* status);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index a2c5a42..f68f8a3 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/strings/base64.h"
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 77e3878..349d9bc 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -399,6 +399,19 @@
                         : d->name().c_str();
 }
 
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+    TFE_TensorHandle* h, TF_Status* status) {
+  if (h == nullptr || h->handle == nullptr) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "The passed in handle is a nullptr");
+    return nullptr;
+  }
+
+  h->handle->Ref();
+
+  return new TFE_TensorHandle(h->handle);
+}
+
 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
   if (h == nullptr || h->handle == nullptr) {
     status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index eec2750..337447e 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -171,6 +171,12 @@
 TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
     TFE_TensorHandle* h, TF_Status* status);
 
+// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
+// with `h`. On success, `status` is set to OK. On failure, `status` reflects
+// the error and a nullptr is returned.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+    TFE_TensorHandle* h, TF_Status* status);
+
 // This function will block till the operation that produces `h` has
 // completed. The memory returned might alias the internal memory used by
 // TensorFlow. Hence, callers should not mutate this memory (for example by
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 7126227..5533102 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1528,4 +1528,29 @@
   TFE_DeleteContext(ctx);
   TF_DeleteStatus(status);
 }
+
+TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
+  TFE_TensorHandle* h = TestMatrixTensorHandle();
+  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  TFE_TensorHandle* h_shares_tensor =
+      TFE_TensorHandleCopySharingTensor(h, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
+  ASSERT_EQ(16, TF_TensorByteSize(t));
+  float data[4] = {0};
+  memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+  EXPECT_EQ(1.0, data[0]);
+  EXPECT_EQ(2.0, data[1]);
+  EXPECT_EQ(3.0, data[2]);
+  EXPECT_EQ(4.0, data[3]);
+  TF_DeleteTensor(t);
+
+  TFE_DeleteTensorHandle(h);
+  TFE_DeleteTensorHandle(h_shares_tensor);
+}
 }  // namespace
diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h
index a085e1d..0717e7d 100644
--- a/tensorflow/cc/framework/ops.h
+++ b/tensorflow/cc/framework/ops.h
@@ -150,7 +150,7 @@
     Initializer(const std::initializer_list<T>& v, const TensorShape& shape) {
       typedef typename RealType<T>::type RealT;
       Tensor t(DataTypeToEnum<RealT>::v(), shape);
-      if (t.NumElements() != v.size()) {
+      if (t.NumElements() != static_cast<int64>(v.size())) {
         status = errors::InvalidArgument(
             "Cannot construct a tensor with ", t.NumElements(),
             " from an initializer list with ", v.size(), " elements");
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 2b1ce34..b17bc65 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_replace.h"
 #include "absl/types/span.h"
@@ -31,7 +32,6 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 namespace tfcompile {
@@ -135,12 +135,12 @@
     indices = "[0]";
   } else {
     for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
-      dim_vars.push_back(strings::StrCat("size_t dim", dim));
-      dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
-      indices += strings::StrCat("[dim", dim, "]");
+      dim_vars.push_back(absl::StrCat("size_t dim", dim));
+      dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
+      indices += absl::StrCat("[dim", dim, "]");
     }
   }
-  rewrites->push_back({"{{I}}", strings::StrCat(i)});
+  rewrites->push_back({"{{I}}", absl::StrCat(i)});
   rewrites->push_back({"{{TYPE}}", type});
   rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
   rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
@@ -194,7 +194,7 @@
         arg_data({{I}}))){{INDICES}};
   }
 )";
-    *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     if (!config.feed(i).name().empty()) {
       *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
     }
@@ -235,7 +235,7 @@
         result_data({{I}}))){{INDICES}};
   }
 )";
-    *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
     if (!config.fetch(i).name().empty()) {
       *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
     }
@@ -304,8 +304,8 @@
                    string encoded_second_as_str =
                        encoded.second == ~0ULL
                            ? "~0ULL"
-                           : strings::StrCat(encoded.second, "ULL");
-                   return strings::StrCat(
+                           : absl::StrCat(encoded.second, "ULL");
+                   return absl::StrCat(
                        "::tensorflow::cpu_function_runtime::BufferInfo({",
                        encoded.first, "ULL, ", encoded_second_as_str, "})");
                  });
@@ -352,13 +352,13 @@
   // Create rewrite strings for namespace start and end.
   string ns_start;
   for (const string& n : opts.namespaces) {
-    ns_start += strings::StrCat("namespace ", n, " {\n");
+    ns_start += absl::StrCat("namespace ", n, " {\n");
   }
   ns_start += "\n";
   string ns_end("\n");
   for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
     const string& n = opts.namespaces[i];
-    ns_end += strings::StrCat("}  // end namespace ", n, "\n");
+    ns_end += absl::StrCat("}  // end namespace ", n, "\n");
   }
 
   // Generate metadata.
@@ -568,10 +568,10 @@
 )";
   // The replacement strategy is naive, but good enough for our purposes.
   const std::vector<std::pair<string, string>> rewrites = {
-      {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
-      {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
+      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
       {"{{ARG_NAMES_CODE}}", arg_names_code},
-      {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
+      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
       {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
       {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
       {"{{CLASS}}", opts.class_name},
@@ -590,11 +590,11 @@
       {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
       {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
        metadata_result.program_shape_access_shim},
-      {"{{RESULT_INDEX}}", strings::StrCat(result_index)},
+      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
       {"{{RESULT_NAMES_CODE}}", result_names_code},
-      {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
-      {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
-      {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
+      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
+      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
+      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
       {"{{BUFFER_INFOS_AS_STRING}}",
        absl::StrJoin(buffer_infos_as_strings, ",\n")}};
   absl::StrReplaceAll(rewrites, header);
@@ -602,13 +602,13 @@
 }
 
 static string CreateUniqueIdentifier(const CodegenOpts& opts,
-                                     StringPiece suffix) {
+                                     absl::string_view suffix) {
   string result = "__tfcompile";
   for (const string& n : opts.namespaces) {
-    strings::StrAppend(&result, "_", n);
+    absl::StrAppend(&result, "_", n);
   }
 
-  strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
+  absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
   return result;
 }
 
@@ -678,7 +678,7 @@
   return Status::OK();
 }
 
-Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
   if (ident.empty()) {
     return errors::InvalidArgument("empty identifier: ", msg);
   }
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 83f2d3e..90410c4 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -19,9 +19,9 @@
 #include <string>
 #include <vector>
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/aot/compile.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 
 namespace tensorflow {
 namespace tfcompile {
@@ -96,7 +96,7 @@
 
 // ValidateCppIdent returns OK iff ident is a valid C++ identifier.  The msg is
 // appended to error messages.
-Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+Status ValidateCppIdent(absl::string_view ident, absl::string_view msg);
 
 }  // namespace tfcompile
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index e3a53ed..bb288d2 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -19,11 +19,11 @@
 #include <vector>
 
 #include "absl/strings/match.h"
+#include "absl/strings/string_view.h"
 #include "llvm/Support/TargetSelect.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index f1e8e5c..3c32d53 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -38,11 +38,11 @@
 
 static void AddEmbeddedProtocolBufferToLlvmModule(
     llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
-    StringPiece unique_identifier, string* protobuf_array_symbol_name,
+    absl::string_view unique_identifier, string* protobuf_array_symbol_name,
     int64* protobuf_array_size) {
   string protobuf_array_contents = proto.SerializeAsString();
   *protobuf_array_symbol_name =
-      strings::StrCat(unique_identifier, "_protobuf_array_contents");
+      absl::StrCat(unique_identifier, "_protobuf_array_contents");
   *protobuf_array_size = protobuf_array_contents.size();
 
   llvm::Constant* protobuf_array_initializer =
@@ -55,9 +55,9 @@
       protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
 }
 
-static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
-                                      StringPiece protobuf_array_symbol_name,
-                                      int64 protobuf_array_size) {
+static string CreateCPPShimExpression(
+    absl::string_view qualified_cpp_protobuf_name,
+    absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) {
   string code =
       "[]() {\n"
       "    {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n"
@@ -68,9 +68,9 @@
   return absl::StrReplaceAll(
       code,
       {
-          {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)},
-          {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)},
-          {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)},
+          {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)},
+          {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)},
+          {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)},
       });
 }
 
@@ -93,7 +93,7 @@
 }
 
 static StatusOr<std::unique_ptr<llvm::TargetMachine>>
-GetTargetMachineFromTriple(StringPiece target_triple) {
+GetTargetMachineFromTriple(absl::string_view target_triple) {
   std::string error;
   std::string normalized_triple =
       llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple)));
@@ -110,7 +110,7 @@
 }
 
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
-    StringPiece target_triple,
+    absl::string_view target_triple,
     absl::Span<const ProtobufToEmbed> protobufs_to_embed) {
   TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
                       GetTargetMachineFromTriple(target_triple));
@@ -135,8 +135,8 @@
           protobuf_to_embed.qualified_cpp_protobuf_name,
           protobuf_array_symbol_name, protobuf_array_size);
 
-      cpp_variable_decl = strings::StrCat("extern \"C\" char ",
-                                          protobuf_array_symbol_name, "[];");
+      cpp_variable_decl =
+          absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];");
     } else {
       cpp_shim = "nullptr";
     }
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index 4f940c0..cf5c04a 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -83,7 +83,7 @@
 // is stored in the object_file_data field in the returned
 // EmbeddedProtocolBuffers instance.
 StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
-    StringPiece target_triple,
+    absl::string_view target_triple,
     absl::Span<const ProtobufToEmbed> protobufs_to_embed);
 
 }  // namespace tfcompile
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 723e9be..7a0932d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -67,7 +67,12 @@
         "test_graph_tfmatmulandadd.pb",
         "test_graph_tfsplits.pb",
     ],
-    cmd = "$(location :make_test_graphs) --out_dir $(@D)",
+    # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
+    # GPUs which might be present.  This is important because builds may run
+    # concurrently with tests, and tests need to be able to assume that they
+    # have control of the full GPU.
+    cmd = "CUDA_VISIBLE_DEVICES='' " +
+          "$(location :make_test_graphs) --out_dir $(@D)",
     tags = ["manual"],
     tools = [":make_test_graphs"],
 )
@@ -226,6 +231,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_profile_printer",
         "//tensorflow/core:lib",
+        "//tensorflow/core:regexp_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index dd2b151..7ac90fb 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -33,6 +33,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/regexp.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -543,7 +544,13 @@
   string hlo_profile_as_string =
       xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
                            /*clock_rate_ghz=*/1.0);
-  VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+  VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
+
+  // Strip away identifier details from the profile string to avoid this test
+  // being a change detector for xla internals. Identifiers such as '%dot.0.7'
+  // just become '%dot'.
+  RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
+  VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
 
   std::vector<string> hlo_profile_lines =
       absl::StrSplit(hlo_profile_as_string, '\n');
@@ -551,16 +558,14 @@
   auto header = HasSubstr("Execution profile for");
   auto total_cycles_profile_line = HasSubstr("[total]");
   auto dot_profile_line = HasSubstr(
-      "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
-      "%arg1.0.1)");
+      "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
   auto add_profile_line = HasSubstr(
-      "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
-      "%arg1.0.1)");
+      "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
   auto tuple_profile_line = HasSubstr(
-      "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
-      "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
-  auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
-  auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+      "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+      "f32[2,2]{1,0} %add)");
+  auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+  auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
 
   EXPECT_THAT(hlo_profile_lines,
               IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 326f73b..792b7fe 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -105,12 +105,18 @@
         freeze_file = freeze_name + ".pb"
 
         # First run tfcompile to generate the list of out_nodes.
+        #
+        # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we
+        # launch from using any GPUs which might be present.  This is important
+        # because builds may run concurrently with tests, and tests need to be
+        # able to assume that they have control of the full GPU.
         out_nodes_file = "out_nodes_" + freeze_name
         native.genrule(
             name = ("gen_" + out_nodes_file),
             srcs = [config],
             outs = [out_nodes_file],
-            cmd = ("$(location " + tfcompile_tool + ")" +
+            cmd = ("CUDA_VISIBLE_DEVICES='' " +
+                   "$(location " + tfcompile_tool + ")" +
                    " --config=$(location " + config + ")" +
                    " --dump_fetch_nodes > $@"),
             tools = [tfcompile_tool],
@@ -142,9 +148,12 @@
                 out_nodes_file,
             ] + freeze_saver_srcs,
             outs = [freeze_file],
-            cmd = ("$(location " +
-                   "//tensorflow/python/tools:freeze_graph)" +
-                   freeze_args),
+            cmd = (
+                "CUDA_VISIBLE_DEVICES='' " +
+                "$(location " +
+                "//tensorflow/python/tools:freeze_graph)" +
+                freeze_args
+            ),
             tools = ["//tensorflow/python/tools:freeze_graph"],
             tags = tags,
         )
@@ -177,16 +186,19 @@
             metadata_object_file,
             function_object_file,
         ],
-        cmd = ("$(location " + tfcompile_tool + ")" +
-               " --graph=$(location " + tfcompile_graph + ")" +
-               " --config=$(location " + config + ")" +
-               " --entry_point=" + ep +
-               " --cpp_class=" + cpp_class +
-               " --target_triple=" + target_llvm_triple() +
-               " --out_header=$(@D)/" + header_file +
-               " --out_metadata_object=$(@D)/" + metadata_object_file +
-               " --out_function_object=$(@D)/" + function_object_file +
-               " " + flags + " " + profiling_flag),
+        cmd = (
+            "CUDA_VISIBLE_DEVICES='' " +
+            "$(location " + tfcompile_tool + ")" +
+            " --graph=$(location " + tfcompile_graph + ")" +
+            " --config=$(location " + config + ")" +
+            " --entry_point=" + ep +
+            " --cpp_class=" + cpp_class +
+            " --target_triple=" + target_llvm_triple() +
+            " --out_header=$(@D)/" + header_file +
+            " --out_metadata_object=$(@D)/" + metadata_object_file +
+            " --out_function_object=$(@D)/" + function_object_file +
+            " " + flags + " " + profiling_flag
+        ),
         tools = [tfcompile_tool],
         visibility = visibility,
         testonly = testonly,
@@ -216,14 +228,17 @@
         outs = [
             session_module_pb,
         ],
-        cmd = ("$(location " + tfcompile_tool + ")" +
-               " --graph=$(location " + tfcompile_graph + ")" +
-               " --config=$(location " + config + ")" +
-               " --entry_point=" + ep +
-               " --cpp_class=" + cpp_class +
-               " --target_triple=" + target_llvm_triple() +
-               " --out_session_module=$(@D)/" + session_module_pb +
-               " " + flags),
+        cmd = (
+            "CUDA_VISIBLE_DEVICES='' " +
+            "$(location " + tfcompile_tool + ")" +
+            " --graph=$(location " + tfcompile_graph + ")" +
+            " --config=$(location " + config + ")" +
+            " --entry_point=" + ep +
+            " --cpp_class=" + cpp_class +
+            " --target_triple=" + target_llvm_triple() +
+            " --out_session_module=$(@D)/" + session_module_pb +
+            " " + flags
+        ),
         tools = [tfcompile_tool],
         visibility = visibility,
         testonly = testonly,
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index f3c44e9..b95b063 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -20,6 +20,7 @@
 
 #include "absl/strings/match.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/aot/codegen.h"
 #include "tensorflow/compiler/aot/compile.h"
 #include "tensorflow/compiler/aot/flags.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -92,8 +92,9 @@
   // Write output files.
   Env* env = Env::Default();
   const std::vector<char>& obj = compile_result.aot->object_file_data();
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
-                                       StringPiece(obj.data(), obj.size())));
+  TF_RETURN_IF_ERROR(
+      WriteStringToFile(env, flags.out_function_object,
+                        absl::string_view(obj.data(), obj.size())));
   CodegenOpts codegen_opts;
   codegen_opts.gen_name_to_index = flags.gen_name_to_index;
   codegen_opts.gen_program_shape = flags.gen_program_shape;
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index df81f3c..7d5db71 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -265,6 +265,7 @@
     srcs = ["jit_compilation_pass_registration.cc"],
     deps = [
         ":compilation_passes",
+        "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
         "//tensorflow/core:core_cpu_internal",
     ],
     alwayslink = 1,
@@ -395,6 +396,7 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -410,6 +412,7 @@
         "//tensorflow/core:graph",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
 )
@@ -479,6 +482,7 @@
         ":common",
         ":compilation_passes",
         ":xla_cluster_util",
+        ":xla_gpu_device",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
@@ -495,6 +499,8 @@
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler/optimizers/data:graph_utils",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -566,6 +572,7 @@
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
         "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 82aa038..9128b48 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -154,7 +154,7 @@
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
 
-    return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
+    return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")");
   }
 
   Kind kind() const override { return Kind::kAnd; }
@@ -185,7 +185,7 @@
                    std::back_inserter(operands_str),
                    [](Predicate* pred) { return pred->ToString(); });
 
-    return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
+    return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")");
   }
 
   Kind kind() const override { return Kind::kOr; }
@@ -206,7 +206,7 @@
         operands_({operand}) {}
 
   string ToString() const override {
-    return strings::StrCat("~", operand()->ToString());
+    return absl::StrCat("~", operand()->ToString());
   }
 
   Kind kind() const override { return Kind::kNot; }
@@ -240,8 +240,8 @@
   Predicate* step() const { return operands_[1]; }
 
   string ToString() const override {
-    return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
-                           "}");
+    return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(),
+                        "}");
   }
 
   Kind kind() const override { return Kind::kAndRecurrence; }
@@ -267,7 +267,7 @@
         must_be_true_(must_be_true) {}
 
   string ToString() const override {
-    return must_be_true() ? strings::StrCat("*", tensor_id_.ToString())
+    return must_be_true() ? absl::StrCat("*", tensor_id_.ToString())
                           : tensor_id_.ToString();
   }
 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 2788102..ae7a22f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
@@ -45,7 +46,6 @@
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/public/version.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -755,7 +755,7 @@
   if (inserted) {
     NodeDef arg_def;
     NodeDefBuilder builder(
-        strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
+        absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp);
     DataType dtype = edge->dst()->input_type(edge->dst_input());
     builder.Attr("T", dtype);
     builder.Attr("index", arg_index);
@@ -790,7 +790,7 @@
   if (inserted) {
     NodeDef ret_def;
     NodeDefBuilder builder(
-        strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
+        absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp);
     DataType dtype = src_node->output_type(src_slot);
     builder.Attr("T", dtype);
     builder.Attr("index", ret_index);
@@ -950,16 +950,15 @@
       }
 
       NodeDef host_compute_def;
-      NodeDefBuilder builder(strings::StrCat("outside_compilation_",
-                                             oc_subgraph_name, "_host_compute"),
+      NodeDefBuilder builder(absl::StrCat("outside_compilation_",
+                                          oc_subgraph_name, "_host_compute"),
                              kHostComputeOp);
       builder.Input(inputs);
       builder.Attr("Tinputs", input_dtypes);
       builder.Attr("Toutputs", output_dtypes);
       builder.Attr("ancestors", host_compute_ancestors);
-      builder.Attr("key",
-                   strings::StrCat("host_compute_channel_", subgraph_name, "_",
-                                   oc_subgraph_name));
+      builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name,
+                                       "_", oc_subgraph_name));
       builder.Attr("_outside_compilation_subgraph", oc_subgraph_name);
       Status s = builder.Finalize(&host_compute_def);
       if (!s.ok()) return s;
@@ -1017,8 +1016,7 @@
                                                   Graph* graph_out) {
   if (sequencer_ == nullptr) {
     NodeDef seq_def;
-    NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"),
-                           "NoOp");
+    NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
     builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
     builder.Device(device_);
     Status s = builder.Finalize(&seq_def);
@@ -1091,10 +1089,10 @@
 
   if (VLOG_IS_ON(1)) {
     VLOG(2) << "Build function def " << name;
-    dump_graph::DumpGraphToFile(
-        strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library);
-    dump_graph::DumpFunctionDefToFile(
-        strings::StrCat("encapsulate_fdef_", name), fdef);
+    dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name),
+                                *graph_, library);
+    dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name),
+                                      fdef);
   }
 
   if (!reuse_existing_functions || library->Find(name) == nullptr) {
@@ -1130,8 +1128,8 @@
     host_compute->AddAttr("shapes", shapes);
   } else {
     string inference_graph_name =
-        strings::StrCat("_outside_compilation_shape_inference_", subgraph_name,
-                        "_", outside_compilation_subgraph_name);
+        absl::StrCat("_outside_compilation_shape_inference_", subgraph_name,
+                     "_", outside_compilation_subgraph_name);
     FunctionDef fdef;
     TF_RETURN_IF_ERROR(
         GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
@@ -1155,10 +1153,10 @@
   if (VLOG_IS_ON(1)) {
     VLOG(2) << "Replace function def " << name;
     dump_graph::DumpGraphToFile(
-        strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
+        absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
         library);
     dump_graph::DumpFunctionDefToFile(
-        strings::StrCat("replace_encapsulate_fdef_", name), fdef);
+        absl::StrCat("replace_encapsulate_fdef_", name), fdef);
   }
 
   TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
@@ -1186,8 +1184,7 @@
   GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
   NodeDef key_def;
   NodeDefBuilder builder(
-      strings::StrCat(call_node_def_.name(), "_key_placeholder"),
-      "Placeholder");
+      absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder");
   builder.Attr("dtype", DT_STRING);
   builder.Attr("shape", shape_proto);
   builder.Attr("_host_compute_call_node", call_node_def_.name());
@@ -1221,16 +1218,16 @@
   }
 
   NodeDef recv_def;
-  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
-                                         "_", oc_subgraph_name, "_recv"),
+  NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
+                                      "_", oc_subgraph_name, "_recv"),
                          kRecvAtHostOp);
   builder.Device(device_);
   builder.Attr("Toutputs", dtypes);
   // The correct device_ordinal will be inserted during replication in a
   // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
-  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
-                                      "_", oc_subgraph_name));
+  builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
   builder.Attr(group_attribute, subgraph_name);
   builder.Attr(outside_compilation_attribute, oc_subgraph_name);
   builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING);
@@ -1276,13 +1273,13 @@
   }
 
   NodeDef send_def;
-  NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
-                                         "_", oc_subgraph_name, "_send"),
+  NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name,
+                                      "_", oc_subgraph_name, "_send"),
                          kSendFromHostOp);
   builder.Device(device_);
   builder.Attr("Tinputs", dtypes);
-  builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
-                                      "_", oc_subgraph_name));
+  builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_",
+                                   oc_subgraph_name));
   // The correct device_ordinal will be inserted during replication in a
   // subsequent rewrite.
   builder.Attr("device_ordinal", 0);
@@ -1516,7 +1513,7 @@
     // Dump subgraphs.
     for (auto& entry : subgraphs_) {
       dump_graph::DumpGraphToFile(
-          strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
+          absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
           *entry.second.GetGraph(), library);
     }
   }
@@ -2052,7 +2049,7 @@
   struct SubgraphAndClusterHash {
     inline std::size_t operator()(const SubgraphAndCluster& v) const {
       return hash<string>()(
-          strings::StrCat(v.subgraph, v.outside_compilation_cluster));
+          absl::StrCat(v.subgraph, v.outside_compilation_cluster));
     }
   };
 
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 7bc0ef0..4995809 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -16,6 +16,7 @@
 #include <memory>
 #include <utility>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 
 #include "absl/strings/match.h"
@@ -48,7 +49,7 @@
   FunctionDef* fdef = library->add_function();
   TF_RETURN_IF_ERROR(GraphToFunctionDef(
       *graph,
-      strings::StrCat("_outside_compilation_shape_inference_", name_suffix),
+      absl::StrCat("_outside_compilation_shape_inference_", name_suffix),
       fdef));
   return Status::OK();
 }
@@ -65,18 +66,18 @@
     const auto iter = b.find(elt_a.first);
     if (iter == b.end()) {
       if (diff) {
-        *diff = strings::StrCat(
-            map_name, " expected: contains element with key '",
-            key_to_string(elt_a.first), "' got: map has no such element");
+        *diff = absl::StrCat(map_name, " expected: contains element with key '",
+                             key_to_string(elt_a.first),
+                             "' got: map has no such element");
       }
       return false;
     }
     if (!compare(elt_a.first, elt_a.second, iter->second)) {
       if (diff) {
-        *diff = strings::StrCat(map_name, " expected: element with key '",
-                                key_to_string(elt_a.first), "' has value '",
-                                value_to_string(elt_a.second), "' got: '",
-                                value_to_string(iter->second), "'");
+        *diff = absl::StrCat(map_name, " expected: element with key '",
+                             key_to_string(elt_a.first), "' has value '",
+                             value_to_string(elt_a.second), "' got: '",
+                             value_to_string(iter->second), "'");
       }
       return false;
     }
@@ -85,9 +86,9 @@
     const auto iter = a.find(elt_b.first);
     if (iter == a.end()) {
       if (diff) {
-        *diff = strings::StrCat(map_name, " got: contains element with key '",
-                                key_to_string(elt_b.first),
-                                "' expected: map has no such element");
+        *diff = absl::StrCat(map_name, " got: contains element with key '",
+                             key_to_string(elt_b.first),
+                             "' expected: map has no such element");
       }
       return false;
     }
@@ -99,25 +100,25 @@
                           const string& diff_preamble, string* diff) {
   if (a.op() != b.op()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                              ", expected op '", a.op(), "' got '", b.op());
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                           ", expected op '", a.op(), "' got '", b.op());
     }
     return false;
   }
   if (a.device() != b.device()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                              ", expected device '", a.device(), "' got '",
-                              b.device());
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                           ", expected device '", a.device(), "' got '",
+                           b.device());
     }
     return false;
   }
   if (a.input_size() != b.input_size()) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                              ", expected ", a.input_size(), " inputs got ",
-                              b.input_size(), " expected:\n", a.DebugString(),
-                              "\ngot:\n", b.DebugString());
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                           ", expected ", a.input_size(), " inputs got ",
+                           b.input_size(), " expected:\n", a.DebugString(),
+                           "\ngot:\n", b.DebugString());
     }
     return false;
   }
@@ -127,10 +128,10 @@
     if (absl::StartsWith(a.input(i), "^")) {
       if (!absl::StartsWith(b.input(i), "^")) {
         if (diff) {
-          *diff = strings::StrCat(
-              diff_preamble, " mismatch for node ", a.name(), " input ", i,
-              ", expected control input ", a.input(i), " got ", b.input(i),
-              " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString());
+          *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                               " input ", i, ", expected control input ",
+                               a.input(i), " got ", b.input(i), " expected:\n",
+                               a.DebugString(), "\ngot:\n", b.DebugString());
         }
         return false;
       }
@@ -138,19 +139,19 @@
       control_input_b.insert(b.input(i));
     } else if (a.input(i) != b.input(i)) {
       if (diff) {
-        *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                                " input ", i, ", expected ", a.input(i),
-                                " got ", b.input(i), " expected:\n",
-                                a.DebugString(), "\ngot:\n", b.DebugString());
+        *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                             " input ", i, ", expected ", a.input(i), " got ",
+                             b.input(i), " expected:\n", a.DebugString(),
+                             "\ngot:\n", b.DebugString());
       }
       return false;
     }
   }
   if (control_input_a != control_input_b) {
     if (diff) {
-      *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
-                              " control inputs differ expected:\n",
-                              a.DebugString(), "\ngot:\n", b.DebugString());
+      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
+                           " control inputs differ expected:\n",
+                           a.DebugString(), "\ngot:\n", b.DebugString());
     }
     return false;
   }
@@ -170,18 +171,17 @@
           return av.DebugString() == bv.DebugString();
         }
       },
-      strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
-      diff);
+      absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
 }
 
 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
                       string* diff) {
   if (a.signature().DebugString() != b.signature().DebugString()) {
     if (diff) {
-      *diff = strings::StrCat("Signature mismatch for function ",
-                              a.signature().name(), ", expected:\n",
-                              a.signature().DebugString(), "\ngot:\n",
-                              b.signature().DebugString());
+      *diff =
+          absl::StrCat("Signature mismatch for function ", a.signature().name(),
+                       ", expected:\n", a.signature().DebugString(), "\ngot:\n",
+                       b.signature().DebugString());
     }
     return false;
   }
@@ -191,7 +191,7 @@
           [](const string& key, const AttrValue& av, const AttrValue& bv) {
             return av.DebugString() == bv.DebugString();
           },
-          strings::StrCat("attr mismatch for function ", a.signature().name()),
+          absl::StrCat("attr mismatch for function ", a.signature().name()),
           diff)) {
     return false;
   }
@@ -201,7 +201,7 @@
           [](const string& key, const string& av, const string& bv) {
             return av == bv;
           },
-          strings::StrCat("ret mismatch for function ", a.signature().name()),
+          absl::StrCat("ret mismatch for function ", a.signature().name()),
           diff)) {
     return false;
   }
@@ -211,7 +211,7 @@
       if (a.node_def(i).name() == b.node_def(j).name()) {
         if (!EqualFunctionNodeDef(
                 a.node_def(i), b.node_def(j),
-                strings::StrCat("Function ", a.signature().name()), diff)) {
+                absl::StrCat("Function ", a.signature().name()), diff)) {
           return false;
         }
         found = true;
@@ -220,9 +220,9 @@
     }
     if (!found) {
       if (diff) {
-        *diff = strings::StrCat("Function ", a.signature().name(),
-                                ", expected: has node '", a.node_def(i).name(),
-                                "' got: no node of that name");
+        *diff = absl::StrCat("Function ", a.signature().name(),
+                             ", expected: has node '", a.node_def(i).name(),
+                             "' got: no node of that name");
       }
       return false;
     }
@@ -237,9 +237,9 @@
     }
     if (!found) {
       if (diff) {
-        *diff = strings::StrCat("Function ", a.signature().name(),
-                                ", got: has node '", b.node_def(i).name(),
-                                "' expected: no node of that name");
+        *diff = absl::StrCat("Function ", a.signature().name(),
+                             ", got: has node '", b.node_def(i).name(),
+                             "' expected: no node of that name");
       }
       return false;
     }
@@ -258,8 +258,8 @@
     auto it = actual_index.find(expected_function.signature().name());
     if (it == actual_index.end()) {
       if (diff) {
-        *diff = strings::StrCat("Did not find expected function '",
-                                expected_function.signature().name(), "'");
+        *diff = absl::StrCat("Did not find expected function '",
+                             expected_function.signature().name(), "'");
       }
       return false;
     }
@@ -269,9 +269,9 @@
 
   if (!actual_index.empty()) {
     if (diff != nullptr) {
-      *diff = strings::StrCat("Found unexpected function '",
-                              actual_index.begin()->second->signature().name(),
-                              "'");
+      *diff =
+          absl::StrCat("Found unexpected function '",
+                       actual_index.begin()->second->signature().name(), "'");
     }
     return false;
   }
@@ -420,10 +420,9 @@
                  const string& oc_cluster, absl::Span<const DataType> dtypes,
                  const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  string key =
-      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
-  string name = strings::StrCat("outside_compilation_", cluster, "_",
-                                oc_cluster, "_recv");
+  string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name =
+      absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv");
   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"),
                            "_XlaRecvAtHost", opts.op_registry());
   node_builder.Input(std::move(key_input));
@@ -440,10 +439,9 @@
                    const std::vector<ops::NodeOut>& inputs,
                    const GraphDefBuilder::Options& opts) {
   if (opts.HaveError()) return nullptr;
-  string key =
-      strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
-  string name = strings::StrCat("outside_compilation_", cluster, "_",
-                                oc_cluster, "_send");
+  string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster);
+  string name =
+      absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send");
   NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"),
                            "_XlaSendFromHost", opts.op_registry());
   node_builder.Input(inputs);
@@ -682,8 +680,8 @@
   for (const Edge* edge : graph.edges()) {
     if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
     edges.emplace_back(
-        strings::StrCat(edge->src()->name(), ":", edge->src_output()),
-        strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
+        absl::StrCat(edge->src()->name(), ":", edge->src_output()),
+        absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
   }
   std::sort(edges.begin(), edges.end());
   return edges;
diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD
index 676f71a..8212956 100644
--- a/tensorflow/compiler/jit/graphcycles/BUILD
+++ b/tensorflow/compiler/jit/graphcycles/BUILD
@@ -14,6 +14,7 @@
     hdrs = ["graphcycles.h"],
     deps = [
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:inlined_vector",
     ],
 )
 
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
index 805bbc6..756377b 100644
--- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -34,7 +34,7 @@
 #include <algorithm>
 #include <unordered_set>
 
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
@@ -44,7 +44,7 @@
 typedef std::unordered_set<int32> NodeSet;
 template <typename T>
 struct VecStruct {
-  typedef gtl::InlinedVector<T, 4> type;
+  typedef absl::InlinedVector<T, 4> type;
 };
 template <typename T>
 using Vec = typename VecStruct<T>::type;
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index c37b611..5dcf754 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -21,6 +21,18 @@
 
 namespace tensorflow {
 
+// PRE_PLACEMENT passes:
+
+// from
+// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
+// FunctionalizeControlFlowPass: 27
+//
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (XlaIf/XlaWhile). Following passes must
+// handle those FunctionDefs correctly.
+
+// POST_REWRITE_FOR_EXEC passes:
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
                       MarkForCompilationPass);
 
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 5b6692f..07c5b23 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -29,18 +29,6 @@
 )
 
 cc_library(
-    name = "parallel_check_op_flags",
-    srcs = ["parallel_check_op_flags.cc"],
-    hdrs = ["parallel_check_op_flags.h"],
-    deps =
-        [
-            "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
-            "//tensorflow/core:framework_internal",
-            "//tensorflow/core:lib",
-        ],
-)
-
-cc_library(
     name = "xla_device_flags",
     srcs = ["xla_device_flags.cc"],
     hdrs = ["xla_device_flags.h"],
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
deleted file mode 100644
index a61694b..0000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <mutex>
-#include <vector>
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static ParallelCheckOpFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags.  Called via call_once(&flags_init,...).
-static void AllocateFlags() {
-  flags = new ParallelCheckOpFlags;
-  flags->parallel_check_failfast = true;
-  flags->parallel_check_atol = "1e-5";
-  flags->parallel_check_rtol = "1e-5";
-  flag_list = new std::vector<Flag>({
-      Flag("parallel_check_failfast", &flags->parallel_check_failfast,
-           "Fail immediately on first parallel-check comparison error."),
-      Flag("parallel_check_atol", &flags->parallel_check_atol,
-           "Absolute error tolerance for parallel-check comparison."),
-      Flag("parallel_check_rtol", &flags->parallel_check_rtol,
-           "Relative error tolerance for parallel-check comparison."),
-  });
-  xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
-  std::call_once(flags_init, &AllocateFlags);
-  append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags() {
-  std::call_once(flags_init, &AllocateFlags);
-  return flags;
-}
-
-}  // namespace legacy_flags
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
deleted file mode 100644
index 156a2a2..0000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// parallel_check_op module.
-typedef struct {
-  bool parallel_check_failfast;  // Fail immediately on first parallel-check
-                                 // comparison error.
-  string parallel_check_atol;    // Absolute error tolerance for parallel-check
-                                 // comparison.
-  string parallel_check_rtol;    // Relative error tolerance for parallel-check
-                                 // comparison.
-} ParallelCheckOpFlags;
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags();
-
-}  // namespace legacy_flags
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 4e4abad..e6cc6e5 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -43,7 +43,6 @@
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/public/version.h"
 
@@ -444,7 +443,7 @@
         !registration->requires_compilation) {
       const OpDef* op_def;
       TF_RETURN_IF_ERROR(
-          OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def));
+          graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
       if (op_def->is_stateful()) {
         // We need to be able to constant fold the nodes in
         // compile_time_const_nodes given constant inputs (required by XLA) and
@@ -617,7 +616,7 @@
 }
 
 static string RatioToString(int numerator, int denominator) {
-  return strings::Printf("%d / %d (%.2f%%)", numerator, denominator,
+  return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator,
                          (100.0 * numerator) / denominator);
 }
 
@@ -626,14 +625,14 @@
     return;
   }
 
-  std::map<StringPiece, int> cluster_name_to_size;
-  std::map<StringPiece, std::map<StringPiece, int>>
+  std::map<absl::string_view, int> cluster_name_to_size;
+  std::map<absl::string_view, std::map<absl::string_view, int>>
       cluster_name_to_op_histogram;
-  std::map<StringPiece, int> unclustered_op_histogram;
+  std::map<absl::string_view, int> unclustered_op_histogram;
   int clustered_node_count = 0;
 
   for (Node* n : g.nodes()) {
-    absl::optional<StringPiece> cluster_name = GetXlaClusterForNode(*n);
+    absl::optional<absl::string_view> cluster_name = GetXlaClusterForNode(*n);
     if (cluster_name) {
       clustered_node_count++;
       cluster_name_to_size[*cluster_name]++;
@@ -650,7 +649,7 @@
           << RatioToString(clustered_node_count, g.num_nodes());
 
   for (const auto& cluster_name_size_pair : cluster_name_to_size) {
-    StringPiece cluster_name = cluster_name_size_pair.first;
+    absl::string_view cluster_name = cluster_name_size_pair.first;
     int size = cluster_name_size_pair.second;
     VLOG(2) << "  " << cluster_name << " "
             << RatioToString(size, g.num_nodes());
@@ -670,14 +669,15 @@
   }
 
   struct EdgeInfo {
-    StringPiece node_name;
-    absl::optional<StringPiece> cluster_name;
+    absl::string_view node_name;
+    absl::optional<absl::string_view> cluster_name;
 
-    StringPiece GetClusterName() const {
+    absl::string_view GetClusterName() const {
       return cluster_name ? *cluster_name : "[none]";
     }
 
-    std::pair<StringPiece, absl::optional<StringPiece>> AsPair() const {
+    std::pair<absl::string_view, absl::optional<absl::string_view>> AsPair()
+        const {
       return {node_name, cluster_name};
     }
 
@@ -686,19 +686,21 @@
     }
   };
 
-  using EdgeInfoMap = std::map<StringPiece, std::map<EdgeInfo, int64>>;
+  using EdgeInfoMap = std::map<absl::string_view, std::map<EdgeInfo, int64>>;
 
   EdgeInfoMap incoming_edge_infos;
   EdgeInfoMap outgoing_edge_infos;
 
-  std::set<StringPiece> cluster_names_to_print;
+  std::set<absl::string_view> cluster_names_to_print;
 
   for (const Edge* e : g.edges()) {
     const Node* from = e->src();
-    absl::optional<StringPiece> from_cluster_name = GetXlaClusterForNode(*from);
+    absl::optional<absl::string_view> from_cluster_name =
+        GetXlaClusterForNode(*from);
 
     const Node* to = e->dst();
-    absl::optional<StringPiece> to_cluster_name = GetXlaClusterForNode(*to);
+    absl::optional<absl::string_view> to_cluster_name =
+        GetXlaClusterForNode(*to);
 
     if (to_cluster_name == from_cluster_name) {
       continue;
@@ -721,9 +723,9 @@
     VLOG(2) << "   [none]";
   }
 
-  auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name,
+  auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name,
                                              const EdgeInfoMap& edge_info_map,
-                                             StringPiece desc) {
+                                             absl::string_view desc) {
     auto it = edge_info_map.find(cluster_name);
     if (it != edge_info_map.end()) {
       VLOG(2) << "  " << it->second.size() << " " << desc << " edges";
@@ -737,7 +739,7 @@
     }
   };
 
-  for (StringPiece cluster_name : cluster_names_to_print) {
+  for (absl::string_view cluster_name : cluster_names_to_print) {
     VLOG(2) << " ** Cluster " << cluster_name;
     print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos,
                                     "incoming");
@@ -966,7 +968,7 @@
       string& name = cluster_names[cluster];
 
       if (name.empty()) {
-        name = strings::StrCat("cluster_", cluster_sequence_num++);
+        name = absl::StrCat("cluster_", cluster_sequence_num++);
       }
       n->AddAttr(kXlaClusterAttr, name);
       VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 807ab51..c59770a 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
 
+#include "absl/memory/memory.h"
 #include "absl/strings/match.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
@@ -633,7 +634,7 @@
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
   Scope root = Scope::NewRootScope().ExitOnError();
   {
-    auto BuildNoopNode = [](StringPiece name, Graph* graph) {
+    auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
       NodeDefBuilder builder(name, "NoOp");
       NodeDef def;
       TF_CHECK_OK(builder.Finalize(&def));
@@ -847,5 +848,51 @@
   EXPECT_EQ(clusters["shape"], "");
 }
 
+TEST(XlaCompilationTest, RandomShapeWithFunc) {
+  Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
+
+  FunctionDefLibrary flib_def;
+  FunctionDef func = FunctionDefHelper::Create(
+      /*function_name=*/"Stateful_func", /*in_def=*/{},
+      /*out_def=*/{"out: int32"},
+      /*attr_def*/
+      {}, /*node_def=*/
+      {FunctionDefHelper::Const("shape_shape", 2),
+       FunctionDefHelper::Const("minval", 1),
+       FunctionDefHelper::Const("maxval", 20),
+       {{"shape"},
+        "RandomUniformInt",
+        {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
+        {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
+      /*ret_def=*/{{"out", "shape:output:0"}});
+
+  func.mutable_signature()->set_is_stateful(true);
+  *flib_def.add_function() = std::move(func);
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+  NodeDef call_node;
+  call_node.set_name("fn_call");
+  call_node.set_op("Stateful_func");
+  Status status;
+  Node* call = root.graph()->AddNode(call_node, &status);
+  TF_ASSERT_OK(status);
+
+  Output shape = Output(call, 0);
+  Output reshape_input =
+      ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+                       ops::Placeholder::Shape(TensorShape({500, 500})));
+  Output reshape =
+      ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+  auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
+                                                          flib_def);
+  TF_ASSERT_OK(
+      MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));
+
+  std::unordered_map<string, string> clusters = GetClusters(*graph);
+  EXPECT_EQ(clusters["fn_call"], "");
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index a8f09bf..10fc9e8 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,7 +14,11 @@
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
@@ -30,7 +34,7 @@
   MemoryTypeVector input_mtypes, output_mtypes;
 
   for (Node* n : post_order) {
-    absl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
+    absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
     if (!from_cluster) {
       continue;
     }
@@ -79,7 +83,7 @@
       // Check if `dst` is in a different cluster, unclustered, or about to be
       // partially declustered (here we rely on the post-order traversal order).
       // If yes, decluster `n` to avoid the device-to-host memcpy.
-      absl::optional<StringPiece> dst_cluster =
+      absl::optional<absl::string_view> dst_cluster =
           result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst);
       if (from_cluster != dst_cluster) {
         CHECK(result->insert(n).second);
@@ -91,15 +95,16 @@
 }
 
 Status PartiallyDeclusterNode(Graph* graph, Node* n) {
-  StringPiece cluster_name = *GetXlaClusterForNode(*n);
-  gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
+  absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+  absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
   for (const Edge* out_edge : n->out_edges()) {
     if (out_edge->IsControlEdge()) {
       continue;
     }
 
     Node* dst = out_edge->dst();
-    absl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
+    absl::optional<absl::string_view> dst_cluster_name =
+        GetXlaClusterForNode(*dst);
     if (dst_cluster_name != cluster_name) {
       out_edges_to_clone.push_back(out_edge);
     }
@@ -108,7 +113,7 @@
   CHECK(!out_edges_to_clone.empty()) << n->DebugString();
 
   NodeDef ndef = n->def();
-  ndef.set_name(strings::StrCat(n->name(), "/declustered"));
+  ndef.set_name(absl::StrCat(n->name(), "/declustered"));
   RemoveFromXlaCluster(&ndef);
   Status s;
   Node* cloned_node = graph->AddNode(ndef, &s);
@@ -128,30 +133,47 @@
 
   return Status::OK();
 }
-}  // namespace
 
-Status PartiallyDeclusterPass::Run(
-    const GraphOptimizationPassOptions& options) {
-  // NB!  In this pass we assume the only XLA-auto-clusterable operations that
-  // may have side effects are resource variable operations so we don't cluster
-  // those.  The pass will have to be updated if this assumption becomes
-  // invalid.
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
 
-  Graph* graph = options.graph->get();
-
+// Clones nodes to outside their cluster to avoid device-to-host copies.  For
+// instance, converts this:
+//
+//         .....
+//           |
+//           v
+//      A_Clustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// to:
+//
+//         .....
+//          | |
+//          | +-------------+
+//          |               |
+//          v               v
+//      A_Clustered   A_Unclustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
   // When deciding whether to decluster a particular node, we base our decision
   // on if we've decided that some of its consumers have to be declustered too.
   // Iterating the graph in post-order guarantees that consumers have been
   // visited before producers.
   std::vector<Node*> post_order;
   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
-               /*edge_filter=*/[](const Edge& edge) {
-                 return !edge.src()->IsNextIteration();
-               });
+               /*edge_filter=*/NotBackedge);
 
   gtl::FlatSet<Node*> nodes_to_partially_decluster;
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
 
   if (VLOG_IS_ON(3)) {
     for (Node* n : post_order) {
@@ -168,10 +190,133 @@
   }
 
   nodes_to_partially_decluster.clear();
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
   CHECK(nodes_to_partially_decluster.empty());
 
   return Status::OK();
 }
+
+bool IsIntraClusterEdge(const Edge& edge) {
+  absl::optional<absl::string_view> src_cluster_name =
+      GetXlaClusterForNode(*edge.src());
+  absl::optional<absl::string_view> dst_cluster_name =
+      GetXlaClusterForNode(*edge.dst());
+  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+  DeviceType device_type("");
+  TF_RETURN_IF_ERROR(
+      DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+    *result = false;
+  } else {
+    *result = registration->requires_compilation;
+  }
+
+  return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+//   x0 = arg0
+//   x1 = arg1
+//     ...
+//   shape = f(x0, x1, ...)
+//   result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many to one.  For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all.  RNG is one possible
+// example, depending on how we look at it.  But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner.  We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+      *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+  std::vector<Node*> rpo;
+  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+                      /*edge_filter=*/NotBackedge);
+  for (Node* n : rpo) {
+    if (!compile_time_const_nodes[n->id()]) {
+      continue;
+    }
+
+    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+    bool node_on_cluster_edge =
+        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+          absl::optional<absl::string_view> incoming_cluster =
+              GetXlaClusterForNode(*e->src());
+          return !incoming_cluster || *incoming_cluster != cluster_name;
+        });
+
+    // We don't want to decluster F in a graph like
+    //
+    //   Input -> OP -> Shape -> F -> Reshape
+    //
+    // Doing so will break up the cluster.  Even if we were okay with breaking
+    // up the cluster we will at least have to relabel the two clusters to have
+    // different cluster names.
+    //
+    // We may want to revisit this in the future: we may have cases where OP is
+    // a small computation that does not benefit from XLA while XLA can optimize
+    // everything that follows the Reshape.  In these cases it may be wise to
+    // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+    // function.
+    //
+    // Note that we do do the right thing for graphs like:
+    //
+    //   Input -> F0 -> F1 -> Reshape
+    //
+    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+    // encounter F1, decluster it and so on.
+    if (node_on_cluster_edge) {
+      bool must_compile_node;
+      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+      if (!must_compile_node) {
+        VLOG(3) << "Declustering must-be-constant node " << n->name();
+        RemoveFromXlaCluster(n);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status PartiallyDeclusterPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  // NB!  In this pass we assume the only XLA-auto-clusterable operations that
+  // may have side effects are resource variable operations so we don't cluster
+  // those.  The pass will have to be updated if this assumption becomes
+  // invalid.
+
+  Graph* graph = options.graph->get();
+
+  TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+  TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+  return Status::OK();
+}
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
index 6949b50..cfc4ddb 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.h
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@
 
 namespace tensorflow {
 
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable.  There are two reasons why we do this:
 //
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone.  For instance, we convert this:
-//
-//         .....
-//           |
-//           v
-//      A_Clustered ====> C_Unclustered
-//           |
-//           v
-//      B_Clustered
-//
-// to:
-//
-//         .....
-//          | |
-//          | +-------------+
-//          |               |
-//          v               v
-//      A_Clustered   A_Unclustered ====> C_Unclustered
-//           |
-//           v
-//      B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+//  - Reducing device-to-host copies.
+//  - Reducing the number of XLA recompilations.
 class PartiallyDeclusterPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index f61a955..35872da 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,6 +32,7 @@
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -82,7 +84,9 @@
   // Assign all nodes to the CPU device.
   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
   for (Node* n : (*graph)->nodes()) {
-    n->set_assigned_device_name(kCpuDevice);
+    if (n->assigned_device_name().empty()) {
+      n->set_assigned_device_name(kCpuDevice);
+    }
   }
 
   GraphOptimizationPassOptions opt_options;
@@ -91,8 +95,8 @@
   return pass.Run(opt_options);
 }
 
-const Node* FindNodeByName(const Graph& graph, const string& name) {
-  for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+  for (Node* node : graph.nodes()) {
     if (node->name() == name) {
       return node;
     }
@@ -279,5 +283,128 @@
             "ClusteredProducer0/declustered");
   EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
 }
+
+void AddToCluster(absl::Span<Node* const> nodes,
+                  absl::string_view cluster_name) {
+  for (Node* n : nodes) {
+    n->AddAttr(kXlaClusterAttr, string(cluster_name));
+  }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+                                    ops::Placeholder::Attrs{});
+  Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+  Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+  Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+               "cluster_0");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({reshape.node()}, "cluster_0");
+  AddToCluster({shape.node()}, "cluster_1");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+  // This is needed to register the XLA_GPU device.
+  std::vector<Device*> devices;
+  TF_ASSERT_OK(DeviceFactory::AddDevices(
+      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+  // Scope::ToGraph loses the assigned device name since it goes through
+  // GraphDef/NodeDef which does not have a field for the assigned device name.
+  Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+  n->set_assigned_device_name(
+      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+  for (Device* d : devices) {
+    delete d;
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 1ba4a5e..56e35c0 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -165,7 +165,7 @@
 using ResourceOp = std::pair<int, XlaResourceOpKind>;
 
 string ResourceOpToString(const ResourceOp& resource_op) {
-  return strings::StrCat(
+  return absl::StrCat(
       resource_op.first, ": ",
       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
 }
@@ -257,11 +257,11 @@
   std::vector<string> elements_debug_string;
   std::transform(resource_op_set.begin(), resource_op_set.end(),
                  std::back_inserter(elements_debug_string), ResourceOpToString);
-  return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
+  return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
 }
 
 string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
-  return strings::StrCat(
+  return absl::StrCat(
       "[", n.name(), ": ", n.type_string(), "(",
       XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
 }
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 4f2fabd..f85121c 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -17,6 +17,7 @@
 
 #include <unordered_map>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/graph/control_flow.h"
@@ -52,8 +53,8 @@
   };
 
   string description;
-  strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
-                     node_name(dst), " would create a cycle.\n");
+  absl::StrAppend(&description, "Edge from ", node_name(src), " to ",
+                  node_name(dst), " would create a cycle.\n");
   path.resize(path_size);
   for (int32 node_id : path) {
     string ascii_art;
@@ -64,7 +65,7 @@
     } else {
       ascii_art = "+-- ";
     }
-    strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
+    absl::StrAppend(&description, ascii_art, node_name(node_id), "\n");
   }
   return description;
 }
@@ -186,7 +187,7 @@
   return Status::OK();
 }
 
-absl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node) {
   const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
   if (attr_value == nullptr) {
     return absl::nullopt;
@@ -209,6 +210,8 @@
   node_def->mutable_attr()->erase(kXlaClusterAttr);
 }
 
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
 Status AdjustCycleDetectionGraphForResourceOps(
     const Graph* graph, const FunctionLibraryDefinition* flib_def,
     const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index b0439a6..ba218f3 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -47,11 +47,14 @@
 
 // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
 // otherwise returns nullopt.
-absl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
+absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
 
 // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute).
 void RemoveFromXlaCluster(NodeDef* node_def);
 
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
 // Returns true if `node` has a DT_RESOURCE typed input or output.
 bool HasResourceInputOrOutput(const Node& node);
 
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index ef6b0e6..3aa9e9c 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -67,12 +67,12 @@
 string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
   string result = sig.name;
   for (const auto& a : sig.arg_types) {
-    strings::StrAppend(&result, ",", DataTypeString(a.first),
-                       a.second.DebugString());
+    absl::StrAppend(&result, ",", DataTypeString(a.first),
+                    a.second.DebugString());
   }
 
   for (const auto& v : sig.arg_values) {
-    strings::StrAppend(&result, "; ", v.DebugString());
+    absl::StrAppend(&result, "; ", v.DebugString());
   }
   return result;
 }
@@ -259,7 +259,7 @@
     const XlaCompiler::CompileOptions& compile_options,
     bool compile_single_op) {
   CHECK_NE(executable, nullptr);
-  VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
+  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
 
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "num_inputs=" << ctx->num_inputs()
@@ -310,7 +310,7 @@
   // cache eviction.
   mutex_lock entry_lock(entry->mu);
   if (!entry->compiled) {
-    VLOG(1) << "Compilation cache miss for signature: "
+    VLOG(2) << "Compilation cache miss for signature: "
             << SignatureDebugString(signature);
     tensorflow::Env* env = tensorflow::Env::Default();
     const uint64 compile_start_us = env->NowMicros();
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index f31879a..51797de 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -148,10 +148,9 @@
   }
 
   const DeviceAttributes attrs = Device::BuildDeviceAttributes(
-      strings::StrCat(name_prefix, "/device:", device_name, ":",
-                      device_ordinal),
+      absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal),
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
-      strings::StrCat("device: ", device_name, " device"));
+      absl::StrCat("device: ", device_name, " device"));
 
   device->reset(
       new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index ee07c5c..af83c79 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -203,7 +203,7 @@
 }
 
 void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                                               StringPiece tensor_name,
+                                               absl::string_view tensor_name,
                                                Device* device,
                                                Tensor* cpu_tensor,
                                                StatusCallback done) {
@@ -339,7 +339,7 @@
 }
 
 void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                                             StringPiece tensor_name,
+                                             absl::string_view tensor_name,
                                              Device* device, Tensor* cpu_tensor,
                                              StatusCallback done) {
   manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 2e74453..df82421 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -57,7 +57,7 @@
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor, StatusCallback done) const;
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                             StringPiece tensor_name, Device* device,
+                             absl::string_view tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done);
 
   void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
@@ -111,7 +111,7 @@
                              Tensor* device_tensor,
                              StatusCallback done) const override;
   void CopyDeviceTensorToCPU(const Tensor* device_tensor,
-                             StringPiece tensor_name, Device* device,
+                             absl::string_view tensor_name, Device* device,
                              Tensor* cpu_tensor, StatusCallback done) override;
   void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
                                 const StatusCallback& done);
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 13da5d2..49c8582 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -198,33 +198,33 @@
                                                                                \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"),            \
-      GeneratorDatasetOp);                                                     \
+      data::GeneratorDatasetOp);                                               \
   REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")                              \
                               .Device(DEVICE)                                  \
                               .HostMemory("buffer_size")                       \
                               .HostMemory("input_dataset")                     \
                               .HostMemory("handle"),                           \
-                          PrefetchDatasetOp);                                  \
+                          data::PrefetchDatasetOp);                            \
                                                                                \
   REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE),                   \
-                          IteratorHandleOp);                                   \
+                          data::IteratorHandleOp);                             \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("MakeIterator").Device(DEVICE).HostMemory("dataset"),               \
-      MakeIteratorOp);                                                         \
+      data::MakeIteratorOp);                                                   \
   REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),            \
-                          AnonymousIteratorHandleOp);                          \
+                          data::AnonymousIteratorHandleOp);                    \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),              \
-                          IteratorGetNextOp);                                  \
+                          data::IteratorGetNextOp);                            \
   REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),          \
-                          IteratorGetNextSyncOp);                              \
+                          data::IteratorGetNextSyncOp);                        \
   REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                       \
                               .Device(DEVICE)                                  \
                               .HostMemory("string_handle"),                    \
-                          IteratorToStringHandleOp);                           \
+                          data::IteratorToStringHandleOp);                     \
   REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")                   \
                               .Device(DEVICE)                                  \
                               .HostMemory("string_handle"),                    \
-                          IteratorFromStringHandleOp);                         \
+                          data::IteratorFromStringHandleOp);                   \
   REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp)              \
                               .Device(DEVICE)                                  \
                               .HostMemory("output")                            \
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
index 07cfab6..bc0db55 100644
--- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -20,6 +20,7 @@
 #include <unordered_map>
 #include <unordered_set>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -326,7 +327,7 @@
       string& name = cluster_names[cluster];
 
       if (name.empty()) {
-        name = strings::StrCat("cluster_", cluster_sequence_num++);
+        name = absl::StrCat("cluster_", cluster_sequence_num++);
       }
       n->AddAttr(kXlaClusterAttr, name);
       VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 4c9bb2e..d95da63 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -122,7 +122,7 @@
   std::shared_ptr<se::Event> definition_event_;
   // A list of all streams for which the tensor's content is defined for any
   // newly enqueued command.
-  gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+  absl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
   mutex mu_;
 };
 
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 34defe1..2176eae 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -581,6 +581,7 @@
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework",
         "//tensorflow/python:platform_test",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
@@ -1103,6 +1104,7 @@
         "//tensorflow/core:test",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:ops_util",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -1196,7 +1198,7 @@
 
 tf_xla_py_test(
     name = "xla_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["xla_ops_test.py"],
     disabled_backends = ["cpu_ondemand"],
     deps = [
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index df0f214..058576b 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -56,7 +56,7 @@
       # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.test_session(), self.test_scope():
+      with self.cached_session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)
 
         # Initialize variables for numpy implementation.
@@ -98,7 +98,7 @@
       # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.test_session(), self.test_scope():
+      with self.cached_session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)
 
         # Initialize variables for numpy implementation.
@@ -140,7 +140,7 @@
       # TODO: test fails for float16 due to excessive precision requirements.
       if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
         continue
-      with self.test_session(), self.test_scope():
+      with self.cached_session(), self.test_scope():
         variable_scope.get_variable_scope().set_use_resource(True)
 
         # Initialize variables for numpy implementation.
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 04f3b3e..0af74c2 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -58,7 +58,8 @@
     Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
     """
 
-    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
+    os.environ["TF_XLA_FLAGS"] = (
+        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
     config = config_pb2.ConfigProto()
     config.graph_options.optimizer_options.global_jit_level = (
         config_pb2.OptimizerOptions.ON_1)
@@ -77,7 +78,7 @@
 
     labels = GetRunMetadataLabels(run_metadata)
     self.assertEqual(1, XlaLaunchOpCount(labels))
-    self.assertFalse(InLabels(labels, "ListDiff"))
+    self.assertFalse(InLabels(labels, "MatMul"))
 
   def testDenseLayerJitScopeDefinedShape(self):
     """Tests that the dense layer node is properly compiled in jit scope.
@@ -128,7 +129,7 @@
 
     labels = GetRunMetadataLabels(run_metadata)
     self.assertEqual(2, XlaLaunchOpCount(labels))
-    self.assertFalse(InLabels(labels, "ListDiff"))
+    self.assertFalse(InLabels(labels, "MatMul"))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6e0db54..0839fb1 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -489,8 +489,9 @@
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
     arg1 = np.random.rand(2, 2).astype(np.float32)
-    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
-                                  "--tf_xla_cpu_global_jit")
+    os.environ["TF_XLA_FLAGS"] = (
+        "--tf_xla_fusion_only=true "
+        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
     tf_op, tf_count = self.simpleTest(arg0, arg1,
                                       config_pb2.OptimizerOptions.OFF)
     self.assertEqual(0, tf_count)
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 9222db4..c61965b 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.compiler.tests import xla_test
@@ -26,38 +27,167 @@
 from tensorflow.python.platform import test
 
 
-class MatrixBandPartTest(xla_test.XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase):
 
-  def _testMatrixBandPart(self, dtype, shape):
-    with self.cached_session():
-      batch_shape = shape[:-2]
-      mat = np.ones(shape).astype(dtype)
-      batch_mat = np.tile(mat, batch_shape + [1, 1])
-      for lower in -1, 0, 1, shape[-2] - 1:
-        for upper in -1, 0, 1, shape[-1] - 1:
-          band_np = mat
-          if lower >= 0:
-            band_np = np.triu(band_np, -lower)
-          if upper >= 0:
-            band_np = np.tril(band_np, upper)
-          if batch_shape:
-            band_np = np.tile(band_np, batch_shape + [1, 1])
-
-          placeholder = array_ops.placeholder(dtype)
-          with self.test_scope():
-            band = array_ops.matrix_band_part(
-                placeholder,
-                constant_op.constant(lower, dtype=dtypes.int32),
-                constant_op.constant(upper, dtype=dtypes.int32))
-            feed_dict = {placeholder: batch_mat}
-            self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
-
-  def testMatrixBandPart(self):
+  @parameterized.parameters(
+      {
+          'batch_shape': [],
+          'rows': 1,
+          'cols': 1
+      },
+      {
+          'batch_shape': [],
+          'rows': 1,
+          'cols': 2
+      },
+      {
+          'batch_shape': [],
+          'rows': 1,
+          'cols': 7
+      },
+      {
+          'batch_shape': [],
+          'rows': 2,
+          'cols': 1
+      },
+      {
+          'batch_shape': [],
+          'rows': 2,
+          'cols': 2
+      },
+      {
+          'batch_shape': [],
+          'rows': 2,
+          'cols': 7
+      },
+      {
+          'batch_shape': [],
+          'rows': 7,
+          'cols': 1
+      },
+      {
+          'batch_shape': [],
+          'rows': 7,
+          'cols': 2
+      },
+      {
+          'batch_shape': [],
+          'rows': 7,
+          'cols': 7
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 1,
+          'cols': 1
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 1,
+          'cols': 2
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 1,
+          'cols': 7
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 2,
+          'cols': 1
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 2,
+          'cols': 2
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 2,
+          'cols': 7
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 7,
+          'cols': 1
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 7,
+          'cols': 2
+      },
+      {
+          'batch_shape': [2,],
+          'rows': 7,
+          'cols': 7
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 1,
+          'cols': 1
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 1,
+          'cols': 2
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 1,
+          'cols': 7
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 2,
+          'cols': 1
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 2,
+          'cols': 2
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 2,
+          'cols': 7
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 7,
+          'cols': 1
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 7,
+          'cols': 2
+      },
+      {
+          'batch_shape': [1, 3, 2],
+          'rows': 7,
+          'cols': 7
+      },
+  )
+  def testMatrixBandPart(self, batch_shape, rows, cols):
     for dtype in self.float_types:
-      for batch_shape in [[], [2,], [1, 3, 2]]:
-        for rows in 1, 2, 7:
-          for cols in 1, 2, 7:
-            self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+      with self.cached_session():
+        mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
+        batch_mat = np.tile(mat, batch_shape + [1, 1])
+        for lower in -1, 0, 1, rows - 1:
+          for upper in -1, 0, 1, cols - 1:
+            band_np = mat
+            if lower >= 0:
+              band_np = np.triu(band_np, -lower)
+            if upper >= 0:
+              band_np = np.tril(band_np, upper)
+            if batch_shape:
+              band_np = np.tile(band_np, batch_shape + [1, 1])
+
+            placeholder = array_ops.placeholder(dtype)
+            with self.test_scope():
+              band = array_ops.matrix_band_part(
+                  placeholder, constant_op.constant(lower, dtype=dtypes.int32),
+                  constant_op.constant(upper, dtype=dtypes.int32))
+              feed_dict = {placeholder: batch_mat}
+              self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 0faf0fd..bddda6f 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,8 @@
 #include <random>
 #include <unordered_map>
 
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -61,7 +63,6 @@
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session.h"
@@ -81,7 +82,7 @@
 bool tf_xla_test_use_jit = true;
 
 string LocalDeviceToFullDeviceName(const string& device) {
-  return strings::StrCat("/job:localhost/replica:0/task:0/device:", device);
+  return absl::StrCat("/job:localhost/replica:0/task:0/device:", device);
 }
 
 constexpr std::array<DataType, 5> kAllXlaTypes = {
@@ -107,11 +108,12 @@
 
   // Sets an attribute.
   template <class T>
-  OpTestBuilder& Attr(StringPiece attr_name, T&& value);
+  OpTestBuilder& Attr(absl::string_view attr_name, T&& value);
 
   // Overload needed to allow {...} expressions for value.
   template <class T>
-  OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
+  OpTestBuilder& Attr(absl::string_view attr_name,
+                      std::initializer_list<T> value);
 
   // Adds nodes that executes the operator under test on 'device' to 'graphdef'.
   // If 'use_jit' is true, marks the operator under test to be compiled by XLA.
@@ -185,13 +187,13 @@
 }
 
 template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) {
   AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
   return *this;
 }
 
 template <class T>
-OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name,
+OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name,
                                    std::initializer_list<T> value) {
   Attr<std::initializer_list<T>>(attr_name, std::move(value));
   return *this;
@@ -209,7 +211,7 @@
 
   NodeDef* test_def = graphdef->add_node();
   *test_def = node_def_;
-  test_def->set_name(strings::StrCat(name_prefix, "_op_under_test"));
+  test_def->set_name(absl::StrCat(name_prefix, "_op_under_test"));
   test_def->set_device(device);
   AddDefaultsToNodeDef(*op_def, test_def);
   if (use_jit) {
@@ -224,7 +226,7 @@
   // Build feed and fetch nodes.
   for (int i = 0; i < input_types.size(); ++i) {
     NodeDef* def = graphdef->add_node();
-    string name = strings::StrCat(name_prefix, "_input_", i);
+    string name = absl::StrCat(name_prefix, "_input_", i);
     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
                            .Device(device)
                            .Attr("dtype", input_types[i])
@@ -235,7 +237,7 @@
 
   for (int i = 0; i < output_types.size(); ++i) {
     NodeDef* def = graphdef->add_node();
-    string name = strings::StrCat(name_prefix, "_output_", i);
+    string name = absl::StrCat(name_prefix, "_output_", i);
     TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
                            .Device(device)
                            .Attr("T", output_types[i])
@@ -726,11 +728,11 @@
 
 template <typename T>
 string Str(T x) {
-  return strings::StrCat(x);
+  return absl::StrCat(x);
 }
 template <>
 string Str<complex64>(complex64 x) {
-  return strings::StrCat("(", x.real(), ", ", x.imag(), ")");
+  return absl::StrCat("(", x.real(), ", ", x.imag(), ")");
 }
 
 template <typename T>
@@ -740,11 +742,11 @@
   auto Ty = y.flat<T>();
   for (int i = 0; i < Tx.size(); ++i) {
     if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
-      return errors::InvalidArgument(strings::StrCat(
-          i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ",
-          Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(),
-          "atol = ", atol, " rtol = ", rtol,
-          " tol = ", atol + rtol * Abs(Tx(i))));
+      return errors::InvalidArgument(
+          absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)),
+                       " vs. ", Str(Ty(i)), ". x = ", x.DebugString(),
+                       "y = ", y.DebugString(), "atol = ", atol,
+                       " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i))));
     }
   }
   return Status::OK();
@@ -756,7 +758,7 @@
   auto Ty = y.flat<T>();
   for (int i = 0; i < Tx.size(); ++i) {
     if (Tx(i) != Ty(i)) {
-      return errors::InvalidArgument(strings::StrCat(
+      return errors::InvalidArgument(absl::StrCat(
           i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i),
           ". x = ", x.DebugString(), "y = ", y.DebugString()));
     }
@@ -771,14 +773,14 @@
 Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
                        double rtol) {
   if (a.dtype() != b.dtype()) {
-    return errors::InvalidArgument(strings::StrCat(
+    return errors::InvalidArgument(absl::StrCat(
         "Tensors have different types: ", DataTypeString(a.dtype()), " and ",
         DataTypeString(b.dtype())));
   }
   if (!a.IsSameSize(b)) {
-    return errors::InvalidArgument(strings::StrCat(
-        "Tensors have different shapes: ", a.shape().DebugString(), " and ",
-        b.shape().DebugString()));
+    return errors::InvalidArgument(
+        absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(),
+                     " and ", b.shape().DebugString()));
   }
 
   switch (a.dtype()) {
@@ -827,7 +829,7 @@
   }
 
   string cpu_device =
-      LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0"));
+      LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0"));
   string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr);
 
   DeviceNameUtils::ParsedName parsed_name;
@@ -842,7 +844,7 @@
   std::vector<string> expected_inputs, test_inputs;
   std::vector<string> expected_fetches, test_fetches;
   Status status = builder.BuildGraph(
-      strings::StrCat("test", num_tests_, "_expected"), cpu_device,
+      absl::StrCat("test", num_tests_, "_expected"), cpu_device,
       /* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
       &expected_inputs, &expected_fetches);
   if (!status.ok()) {
@@ -851,7 +853,7 @@
   }
 
   NodeDef* node_def;
-  status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
+  status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"),
                               test_device, tf_xla_test_use_jit, &graph,
                               &node_def, &test_inputs, &test_fetches);
   if (!status.ok()) {
diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py
index 84c6777..96e0b07 100644
--- a/tensorflow/compiler/tests/reshape_op_test.py
+++ b/tensorflow/compiler/tests/reshape_op_test.py
@@ -33,7 +33,7 @@
                                   ('64_bit_index', dtypes.int64))
   def testBasic(self, index_dtype):
     for dtype in self.numeric_types:
-      with self.test_session():
+      with self.cached_session():
         i = array_ops.placeholder(dtype, shape=[2, 3])
         with self.test_scope():
           shape = constant_op.constant([3, 2], dtype=index_dtype)
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index b2f026d..1e600c4 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@
 from tensorflow.compiler.tf2xla.python import xla
 from tensorflow.compiler.xla import xla_data_pb2
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import googletest
@@ -97,9 +98,9 @@
         args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)),
         expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32))
 
-  PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT,
-                      xla_data_pb2.PrecisionConfigProto.HIGH,
-                      xla_data_pb2.PrecisionConfigProto.HIGHEST)
+  PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT,
+                      xla_data_pb2.PrecisionConfig.HIGH,
+                      xla_data_pb2.PrecisionConfig.HIGHEST)
 
   @parameterized.parameters(*PRECISION_VALUES)
   def testConv(self, precision):
@@ -120,7 +121,7 @@
         dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims))
         precision_config = None
         if precision:
-          precision_config = xla_data_pb2.PrecisionConfigProto()
+          precision_config = xla_data_pb2.PrecisionConfig()
           precision_config.operand_precision.extend([precision, precision])
         return xla.conv(
             lhs,
@@ -151,7 +152,7 @@
         dnums.rhs_batch_dimensions.append(0)
         precision_config = None
         if precision:
-          precision_config = xla_data_pb2.PrecisionConfigProto()
+          precision_config = xla_data_pb2.PrecisionConfig()
           precision_config.operand_precision.extend([precision, precision])
         return xla.dot_general(
             lhs,
@@ -296,6 +297,44 @@
       self._assertOpOutputMatchesExpected(
           lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
 
+  def testDynamicSlice(self):
+    for dtype in self.numeric_types:
+      self._assertOpOutputMatchesExpected(
+          xla.dynamic_slice,
+          args=(np.arange(1000,
+                          dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+                np.array([5, 7, 3]), np.array([2, 3, 2])),
+          expected=np.array(
+              np.array([[[573, 574], [583, 584], [593, 594]],
+                        [[673, 674], [683, 684], [693, 694]]]),
+              dtype=dtype))
+
+  def testDynamicSliceWithIncorrectStartIndicesShape(self):
+    with self.test_session() as session:
+      with self.test_scope():
+        output = xla.dynamic_slice(
+            np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+            np.array([5, 7]), np.array([2, 3, 4]))
+      with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+        session.run(output)
+      self.assertRegexpMatches(
+          invalid_arg_error.exception.message,
+          (r'^start_indices must be a vector with length equal to input rank, '
+           r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+  def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+    with self.test_session() as session:
+      with self.test_scope():
+        output = xla.dynamic_slice(
+            np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+            np.array([5, 7, 3]), np.array([2, 3]))
+      with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+        session.run(output)
+      self.assertRegexpMatches(
+          invalid_arg_error.exception.message,
+          (r'^size_indices must be a vector with length equal to input rank, '
+           r'but input rank is 3 and size_indices has shape \[2\].*'))
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 0797b2c..d549e7b 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -76,6 +76,7 @@
     deps = [
         ":common",
         ":dump_graph",
+        ":functionalize_control_flow",
         ":tf2xla_proto",
         ":tf2xla_util",
         ":xla_compiler",
@@ -188,9 +189,9 @@
     deps = [
         ":common",
         ":dump_graph",
-        ":functionalize_control_flow",
         ":host_compute_metadata_proto",
         ":sharding_util",
+        ":side_effect_util",
         ":tf2xla_util",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/xla:literal",
@@ -283,6 +284,7 @@
     deps = [
         ":sharding_util",
         ":tf2xla_proto",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
@@ -291,6 +293,7 @@
         "//tensorflow/core:graph",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
 )
@@ -358,6 +361,7 @@
     name = "xla_compiler_test",
     srcs = ["xla_compiler_test.cc"],
     deps = [
+        ":side_effect_util",
         ":xla_compiler",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
@@ -369,6 +373,7 @@
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/service:cpu_plugin",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:core_cpu_internal",
@@ -433,6 +438,7 @@
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -474,6 +480,7 @@
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -501,12 +508,24 @@
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:optional",
     ],
 )
 
 cc_library(
+    name = "functionalize_control_flow_pass_registration",
+    srcs = [
+        "functionalize_control_flow_pass_registration.cc",
+    ],
+    deps = [
+        ":functionalize_control_flow",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
     name = "functionalize_while",
     srcs = [
         "functionalize_while.cc",
@@ -515,6 +534,7 @@
         "functionalize_while.h",
     ],
     deps = [
+        ":functionalize_cond",
         ":functionalize_control_flow_util",
         ":tf2xla_util",
         "//tensorflow/compiler/jit:union_find",
@@ -525,6 +545,7 @@
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:graph",
+        "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:optional",
     ],
@@ -539,6 +560,7 @@
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/compiler/tf2xla/cc:xla_ops",
@@ -609,11 +631,10 @@
     srcs = ["resource_operation_table.cc"],
     hdrs = ["resource_operation_table.h"],
     deps = [
-        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:ops",
-        "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
     ],
 )
 
@@ -630,3 +651,12 @@
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "side_effect_util",
+    srcs = ["side_effect_util.cc"],
+    hdrs = ["side_effect_util.h"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+    ],
+)
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index e8673d7..922ae7c 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -26,8 +26,9 @@
 // Backwards dataflow analysis that finds arguments to a graph that must be
 // compile-time constants.
 Status BackwardsConstAnalysis(const Graph& g,
-                              std::vector<bool>* compile_time_const_args,
-                              std::vector<bool>* compile_time_const_nodes) {
+                              std::vector<bool>* compile_time_const_arg_indices,
+                              std::vector<bool>* compile_time_const_nodes,
+                              std::function<bool(const Edge&)> edge_filter) {
   // Operators that don't look at the data of their inputs, just the shapes.
   const std::unordered_set<string> metadata_ops = {
       "Rank",
@@ -45,8 +46,7 @@
   }
 
   Status status;
-  auto visit = [&status, &metadata_ops, compile_time_const_nodes,
-                compile_time_const_args](Node* node) {
+  auto visit = [&](Node* node) {
     if (!status.ok()) return;
 
     // If this is a metadata-only op, don't propagate the const requirement.
@@ -59,13 +59,13 @@
         int index;
         status = GetNodeAttr(node->attrs(), "index", &index);
         if (!status.ok()) return;
-        if (compile_time_const_args) {
-          (*compile_time_const_args)[index] = true;
+        if (compile_time_const_arg_indices) {
+          (*compile_time_const_arg_indices)[index] = true;
         }
         return;
       }
       for (const Edge* pred : node->in_edges()) {
-        if (!pred->IsControlEdge()) {
+        if (!pred->IsControlEdge() && edge_filter(*pred)) {
           (*compile_time_const_nodes)[pred->src()->id()] = true;
         }
       }
@@ -88,7 +88,8 @@
 
       for (Edge const* edge : node->in_edges()) {
         if (edge->dst_input() >= name_range->second.first &&
-            edge->dst_input() < name_range->second.second) {
+            edge->dst_input() < name_range->second.second &&
+            edge_filter(*edge)) {
           (*compile_time_const_nodes)[edge->src()->id()] = true;
         }
       }
@@ -97,7 +98,8 @@
 
   // Post-order traversal visits nodes in reverse topological order for an
   // acyclic graph.
-  DFS(g, {}, visit);
+  DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
+      [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
   return status;
 }
 
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index af57e5a..49b3c6d 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -32,9 +32,13 @@
 //
 // The ids of the nodes in `graph` that must be constant are returned in
 // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
-Status BackwardsConstAnalysis(const Graph& graph,
+//
+// Only propagate const-ness along edges for which `edge_filter` returns true.
+Status BackwardsConstAnalysis(const Graph& g,
                               std::vector<bool>* compile_time_const_arg_indices,
-                              std::vector<bool>* compile_time_const_nodes);
+                              std::vector<bool>* compile_time_const_nodes,
+                              std::function<bool(const Edge&)> edge_filter =
+                                  [](const Edge& e) { return true; });
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index 24616c0..380c6a7 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -18,8 +18,8 @@
 
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -52,9 +52,9 @@
 
   string filename = name;
   if (count > 0) {
-    strings::StrAppend(&filename, "_", count);
+    absl::StrAppend(&filename, "_", count);
   }
-  strings::StrAppend(&filename, ".pbtxt");
+  absl::StrAppend(&filename, ".pbtxt");
   return filename;
 }
 
@@ -69,7 +69,7 @@
                  << proto_type << ": " << status;
     return "(unavailable)";
   }
-  string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name));
+  string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name));
   status = WriteTextProto(Env::Default(), filepath, proto);
   if (!status.ok()) {
     LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index b5667ca..db256e5 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,30 +34,16 @@
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 
 using xla::StatusOr;
 
 namespace tensorflow {
 namespace functionalize_cond {
 
-string DebugString(const CondStateMap::CondNode& node) {
-  return node.ToString();
-}
-
 // TODO(jpienaar): Move to OutputTensor.
 string DebugString(const OutputTensor& tensor) {
-  return strings::StrCat(tensor.node->name(), ":", tensor.index);
-}
-
-string DebugString(CondStateMap::CondId cond_state) {
-  if (cond_state == nullptr || cond_state->empty()) return "[]";
-  return strings::StrCat(
-      "[",
-      absl::StrJoin(*cond_state, ", ",
-                    [](string* output, const CondStateMap::CondNode& node) {
-                      strings::StrAppend(output, node.ToString());
-                    }),
-      "]");
+  return absl::StrCat(tensor.node->name(), ":", tensor.index);
 }
 
 string Branch_Name(BranchType b) {
@@ -73,6 +59,24 @@
   }
 }
 
+string DebugString(StateMap::CondId cond_state) {
+  if (cond_state == nullptr || cond_state->empty()) return "{}";
+  using value_type = StateMap::CondState::value_type;
+  return absl::StrCat(
+      "{",
+      absl::StrJoin(*cond_state, ", ",
+                    [](string* output, const value_type& pred_branch) {
+                      const OutputTensor& pred = pred_branch.first;
+                      const BranchType& branch = pred_branch.second;
+                      if (branch == BranchType::kNeither)
+                        absl::StrAppend(output, "d");
+                      else
+                        absl::StrAppend(output, "s(", DebugString(pred), ",",
+                                        Branch_Name(branch), ")");
+                    }),
+      "}");
+}
+
 // Returns the predicate of a switch.
 Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) {
   const Edge* pred_edge;
@@ -86,64 +90,65 @@
   return Status::OK();
 }
 
-CondStateMap::CondNode::CondNode(Type type, Node* switch_node,
-                                 BranchType branch)
-    : type(type), branch(branch) {
-  if (type == Type::kSwitch) {
-    TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate));
+Status GetSwitchValue(const Node& switch_node, OutputTensor* val) {
+  const Edge* val_edge;
+  TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge));
+  *val = OutputTensor(val_edge->src(), val_edge->src_output());
+  return Status::OK();
+}
+
+bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs,
+                                            const OutputTensor& rhs) const {
+  return (lhs.node->id() < rhs.node->id()) ||
+         (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index);
+}
+
+struct CondStateLess {
+  bool operator()(const StateMap::CondState::value_type& lhs,
+                  const StateMap::CondState::value_type& rhs) const {
+    if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first))
+      return true;
+    if (lhs.first.node->id() == rhs.first.node->id() &&
+        lhs.first.index == rhs.first.index)
+      return lhs.second < rhs.second;
+    return false;
   }
-}
+};
 
-string CondStateMap::CondNode::ToString() const {
-  switch (type) {
-    case Type::kSwitch:
-      return strings::StrCat("s(", DebugString(predicate), ",",
-                             Branch_Name(branch), ")");
-    case Type::kMerge:
-      return "m";
-    case Type::kDead:
-      return "d";
-  }
-}
-
-bool CondStateMap::CondNode::operator==(const CondNode& other) const {
-  if (type != Type::kSwitch) return type == other.type;
-  return type == other.type && predicate == other.predicate &&
-         branch == other.branch;
-}
-
-bool CondStateMap::CondNode::operator!=(const CondNode& other) const {
-  return !(*this == other);
-}
-
-CondStateMap::CondStateMap(Graph* graph) {
+StateMap::StateMap(Graph* graph) {
   node_to_condid_map_.resize(graph->num_node_ids());
+  node_to_ancestorid_map_.resize(graph->num_node_ids());
   // Initialize the dead state (empty state is designated with a nullptr).
-  dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)});
+  dead_id_ = GetCondId(
+      {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)});
 }
 
-bool CondStateMap::IsDead(CondStateMap::CondId id) const {
-  return id == dead_id_;
+bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; }
+
+bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; }
+
+size_t StateMap::Hash::operator()(const StateMap::CondState& map) const {
+  if (map.empty()) return 0;
+  // Compute hash of the front element.
+  auto it = map.begin();
+  size_t h = Hash64Combine(OutputTensor::Hash()(it->first),
+                           hash<BranchType>()(it->second));
+  for (++it; it != map.end(); ++it) {
+    // Combine the hash with the different elements in the map.
+    h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first),
+                                       hash<BranchType>()(it->second)));
+  }
+  return h;
 }
 
-bool CondStateMap::IsEmpty(CondStateMap::CondId id) const {
-  return id == nullptr;
-}
-
-size_t CondStateMap::CondHash::operator()(
-    const CondStateMap::CondNode& item) const {
-  return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate),
-                                     hash<BranchType>()(item.branch)),
-                       hash<CondStateMap::CondNode::Type>()(item.type));
-}
-
-size_t CondStateMap::CondHash::operator()(
-    const CondStateMap::CondState& vec) const {
-  if (vec.empty()) return 0;
-  size_t h = (*this)(vec.front());
-  auto it = vec.begin();
-  for (++it; it != vec.end(); ++it) {
-    h = Hash64Combine(h, (*this)(*it));
+size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const {
+  if (map.empty()) return 0;
+  // Compute hash of the front element.
+  auto it = map.begin();
+  size_t h = hash<Node*>()(*it);
+  for (++it; it != map.end(); ++it) {
+    // Combine the hash with the different elements in the map.
+    h = Hash64Combine(h, hash<Node*>()(*it));
   }
   return h;
 }
@@ -155,8 +160,8 @@
       : src(src), src_output(src_output) {}
 
   string ToString() const {
-    return strings::StrCat("src=", src->name(), ":", src_output,
-                           " switches=", NodesToString(switches));
+    return absl::StrCat("src=", src->name(), ":", src_output,
+                        " switches=", NodesToString(switches));
   }
 
   Node* src;
@@ -167,58 +172,76 @@
 using CondArgNodes = std::vector<CondArgNode>;
 
 string DebugString(const CondArgNodes& nodes) {
-  return strings::StrCat(
+  return absl::StrCat(
       "[",
       absl::StrJoin(nodes, ", ",
                     [](string* output, const CondArgNode& node) {
-                      strings::StrAppend(output, node.ToString());
+                      absl::StrAppend(output, node.ToString());
                     }),
       "]");
 }
 
-CondStateMap::CondId CondStateMap::LookupId(const Node* node) const {
+StateMap::CondId StateMap::LookupCondId(const Node* node) const {
   if (node->id() < node_to_condid_map_.size())
     return node_to_condid_map_[node->id()];
-  return added_node_mapping_.at(node->id());
+  return added_node_condid_mapping_.at(node->id());
 }
 
-CondStateMap::CondId CondStateMap::GetUniqueId(
-    const CondStateMap::CondState& state) {
+StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) {
   if (state.empty()) return nullptr;
   return &*condstate_set_.insert(state).first;
 }
 
-const CondStateMap::CondState& CondStateMap::LookupState(
-    const Node* node) const {
-  return *LookupId(node);
-}
-
-void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) {
+void StateMap::ResetCondId(const Node* node, StateMap::CondId id) {
   if (node->id() < node_to_condid_map_.size())
     node_to_condid_map_[node->id()] = id;
   else
-    added_node_mapping_[node->id()] = id;
+    added_node_condid_mapping_[node->id()] = id;
 }
 
-void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); }
-
-string CondStateMap::CondStateToString(const Node* node) const {
-  return CondStateToString(LookupId(node));
+StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const {
+  if (node->id() < node_to_ancestorid_map_.size())
+    return node_to_ancestorid_map_[node->id()];
+  return added_node_ancestorid_mapping_.at(node->id());
 }
 
-string CondStateMap::CondStateToString(CondStateMap::CondId id) const {
+StateMap::AncestorId StateMap::GetAncestorId(
+    const StateMap::AncestorState& state) {
+  if (state.empty()) return nullptr;
+  return &*ancestorstate_set_.insert(state).first;
+}
+
+void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
+  if (node->id() < node_to_ancestorid_map_.size())
+    node_to_ancestorid_map_[node->id()] = id;
+  else
+    added_node_ancestorid_mapping_[node->id()] = id;
+}
+
+void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
+
+string StateMap::CondStateToString(const Node* node) const {
+  return CondStateToString(LookupCondId(node));
+}
+
+string StateMap::CondStateToString(StateMap::CondId id) const {
   return DebugString(id);
 }
 
+string StateMap::AncestorStateToString(const Node* node) const {
+  if (auto id = LookupAncestorId(node)) return NodesToString(*id);
+  return "{}";
+}
+
 FunctionalizeCond::FunctionalizeCond(Graph* graph,
                                      FunctionLibraryDefinition* library)
-    : cond_state_map_(graph), library_(library), graph_(graph) {}
+    : state_map_(graph), library_(library), graph_(graph) {}
 
 // Class representing the merge/switch nodes that will become a conditional.
 class Conditional {
  public:
   Conditional(OutputTensor predicate, FunctionalizeCond* parent,
-              CondStateMap* cond_state_map);
+              StateMap* cond_state_map);
 
   // Adds merge node that is part of this conditional.
   Status AddMerge(Node* m);
@@ -247,6 +270,10 @@
   // Adds switch node that is part of this conditional.
   Status AddSwitch(Node* s);
 
+  // Adds a switch node along the edge and rewire the edge to go via the switch.
+  Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+                                Graph* graph);
+
   // Internal name of conditional. The name is based on the first merge node
   // added.
   string name() const;
@@ -255,7 +282,7 @@
   FunctionalizeCond* parent_;
 
   // Mapping between nodes and their cond state.
-  CondStateMap* cond_state_map_;
+  StateMap* state_map_;
 
   // The predicate of the conditional.
   OutputTensor predicate_;
@@ -292,8 +319,8 @@
 };
 
 Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent,
-                         CondStateMap* cond_state_map)
-    : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {}
+                         StateMap* cond_state_map)
+    : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {}
 
 Status Conditional::AddMerge(Node* m) {
   merges_.insert(m);
@@ -343,7 +370,7 @@
     for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
       int branch_index = static_cast<int>(branch);
       TF_RETURN_IF_ERROR(
-          NodeBuilder(strings::StrCat("_Arg", arg_count),
+          NodeBuilder(absl::StrCat("_Arg", arg_count),
                       FunctionLibraryDefinition::kArgOp)
               .Attr("T", dtype)
               .Attr("index", arg_count)
@@ -397,6 +424,35 @@
   return Status::OK();
 }
 
+Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch,
+                                           Graph* graph) {
+  // Previously we had edge:
+  //   src:src_output ---- edge ----> dst:dst_input
+  // post this we have (in graph)
+  //   src:src_output --> switch<pred> --- new_edge --> dst:dst_input
+
+  // TODO(jpienaar): One could keep a map caching the extra switch nodes added
+  // to avoid adding another switch to feed a value for which a switch was
+  // already added.
+  Node* switch_node;
+  Node* src = edge->src();
+  int src_output = edge->src_output();
+  TF_RETURN_IF_ERROR(
+      NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")),
+                  "Switch")
+          .Input(src, src_output)
+          .Input(const_cast<Node*>(predicate_.node), predicate_.index)
+          .Finalize(graph, &switch_node));
+  state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src));
+  state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src));
+
+  Node* dst = edge->dst();
+  int dst_input = edge->dst_input();
+  graph->RemoveEdge(edge);
+  graph->AddEdge(switch_node, static_cast<int>(branch), dst, dst_input);
+  return AddSwitch(switch_node);
+}
+
 Status Conditional::ExtractBodies(Graph* graph) {
   VLOG(2) << "Extracting bodies for " << name();
   for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) {
@@ -405,16 +461,16 @@
   }
 
   auto find_branch = [&](const Edge* e) {
-    const auto& id = cond_state_map_->LookupId(e->src());
+    const auto& id = state_map_->LookupCondId(e->src());
     return IsSwitch(e->src()) ? BranchType(e->src_output())
-                              : cond_state_map_->FindBranchOf(id, predicate_);
+                              : state_map_->FindBranchOf(id, predicate_);
   };
 
   std::array<std::vector<Node*>, 2> stacks;
   VLOG(5) << "Merges: " << NodesToString(merges_);
   for (Node* m : merges_) {
     VLOG(5) << "For merge: " << m->DebugString() << " "
-            << cond_state_map_->CondStateToString(m);
+            << state_map_->CondStateToString(m);
     for (auto e : m->in_edges()) {
       if (e->IsControlEdge()) continue;
       BranchType branch = find_branch(e);
@@ -422,7 +478,8 @@
                    branch == BranchType::kElseBranch)
           << "Error: " << e->src()->name()
           << " is not on either then or else branch (" << Branch_Name(branch)
-          << ").";
+          << ") for predicate " << DebugString(predicate_) << " ["
+          << DebugString(state_map_->LookupCondId(e->src())) << "].";
       Node* src = e->src();
       if (IsSwitch(src)) {
         // Switch node outputs and dependencies are handled separately.
@@ -456,8 +513,8 @@
         if (IsMerge(dst)) continue;
         Node* src = e->src();
 
-        auto dst_id = cond_state_map_->LookupId(dst);
-        auto src_id = cond_state_map_->LookupId(src);
+        auto dst_id = state_map_->LookupCondId(dst);
+        auto src_id = state_map_->LookupCondId(src);
         if (dst_id != src_id) {
           if (e->IsControlEdge()) {
             external_control_outputs_.push_back(e->src());
@@ -480,8 +537,11 @@
         }
       }
 
-      // Copying incomming edges to dst node.
-      for (const Edge* e : n->in_edges()) {
+      // Copy incoming edges to the dst node. Iterate over a copy of the edges
+      // as they could be mutated during iteration.
+      std::vector<const Edge*> in_edges(n->in_edges().begin(),
+                                        n->in_edges().end());
+      for (const Edge* e : in_edges) {
         Node* src = e->src();
         // Skip src/dst node.
         if (!src->IsOp()) continue;
@@ -494,8 +554,8 @@
         }
 
         // Verify input is from the same context.
-        auto src_id = cond_state_map_->LookupId(src);
-        auto dst_id = cond_state_map_->LookupId(dst);
+        auto src_id = state_map_->LookupCondId(src);
+        auto dst_id = state_map_->LookupCondId(dst);
         if (IsMerge(dst) || src_id == dst_id) {
           // TODO(jpienaar): The merge case can be more strict.
           if (node_map.at(src->id()) == nullptr) {
@@ -506,18 +566,25 @@
           external_control_inputs_.push_back(src);
         } else {
           // This shouldn't happen, this means we have an external data input
-          // not entering via a switch node. Work around this for constant
-          // nodes as some constant nodes are inserted without the required
-          // control context dominance.
+          // not entering via a switch node. Work around this as follows:
+          // * for constant nodes, copy them;
+          // * for non-constant nodes, insert a switch along the edge;
           if (IsConstant(src)) {
             node_map.at(src->id()) = output->CopyNode(src);
           } else {
-            return errors::InvalidArgument(
-                "Graph contains node ", FormatNodeForError(*src),
-                " that feeds into node ", FormatNodeForError(*dst),
-                " but these nodes are in different control contexts (",
-                DebugString(src_id), " vs ", DebugString(dst_id),
-                " (detected during in edge testing)");
+            StateMap::CondState state = *dst_id;
+            state.erase(predicate_);
+            if (state_map_->GetCondId(state) == src_id) {
+              TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph));
+              continue;
+            } else {
+              return errors::InvalidArgument(
+                  "Graph contains node ", FormatNodeForError(*src),
+                  " that feeds into node ", FormatNodeForError(*dst),
+                  " but these nodes are in different control contexts (",
+                  DebugString(src_id), " vs ", DebugString(dst_id),
+                  " (detected during in edge testing)");
+            }
           }
         }
 
@@ -572,7 +639,7 @@
 Status Conditional::BuildIfNode(Graph* graph,
                                 FunctionLibraryDefinition* library) {
   VLOG(2) << "Build cond function for " << name();
-  NodeDefBuilder builder(name(), "If");
+  NodeDefBuilder builder(name(), "If", library);
   const string branch_name[] = {"else_branch", "then_branch"};
   for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
     int branch_index = static_cast<int>(branch);
@@ -580,8 +647,8 @@
     int64 id = ++sequence_num;
 
     NameAttrList body_name;
-    body_name.set_name(strings::StrCat("_functionalize_if_",
-                                       branch_name[branch_index], "_", id));
+    body_name.set_name(
+        absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id));
 
     VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index]
             << "): "
@@ -639,7 +706,8 @@
   VLOG(3) << "Build If node";
   NodeDef if_def;
   TF_RETURN_IF_ERROR(builder.Finalize(&if_def));
-  TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin()));
+  TF_ASSIGN_OR_RETURN(if_node_,
+                      parent_->AddIfNode(if_def, *merges_.begin(), predicate_));
 
   return Status::OK();
 }
@@ -699,7 +767,8 @@
 
 Status Conditional::BuildAndReplace(Graph* graph,
                                     FunctionLibraryDefinition* library) {
-  VLOG(1) << "Build If and replace merge nodes " << name();
+  VLOG(1) << "Build If and replace merge nodes "
+          << NodesToString(this->merges_);
   if (replaced_) return Status::OK();
 
   TF_RETURN_IF_ERROR(ExtractBodies(graph));
@@ -719,7 +788,6 @@
   TF_RETURN_IF_ERROR(AddInputEdges(graph));
   TF_RETURN_IF_ERROR(AddOutputEdges(graph));
   TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
-  for (Node* m : merges_) cond_state_map_->MarkDead(m);
 
   // Check that the if_node doesn't feed into itself.
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -732,31 +800,7 @@
 
 string Conditional::name() const {
   CHECK(!merges_.empty());
-  return strings::StrCat((*merges_.begin())->name(), "_if");
-}
-
-bool CondStateMap::ScopeIn(CondStateMap::CondId id,
-                           CondStateMap::CondId* scope) {
-  if (id == nullptr) {
-    *scope = nullptr;
-    return true;
-  }
-  CondState state;
-  for (const CondNode& node : *id) {
-    if (node.type == CondNode::Type::kSwitch) {
-      state.push_back(node);
-    }
-    if (node.type == CondNode::Type::kMerge) {
-      if (state.empty()) {
-        return false;
-      }
-      DCHECK(state.back().type == CondNode::Type::kSwitch &&
-             state.back().branch == BranchType::kBoth);
-      state.pop_back();
-    }
-  }
-  *scope = GetUniqueId(state);
-  return true;
+  return absl::StrCat((*merges_.begin())->name(), "_if");
 }
 
 Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node,
@@ -765,25 +809,35 @@
   TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity")
                          .Input(if_node, port)
                          .Finalize(graph_, &id));
-  cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node));
+  state_map_.ResetCondId(id, state_map_.LookupCondId(if_node));
+  state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node));
   return Status::OK();
 }
 
 StatusOr<Node*> FunctionalizeCond::AddIfNode(const NodeDef& def,
-                                             const Node* replacee) {
+                                             const Node* replacee,
+                                             const OutputTensor& predicate) {
   Status status;
   Node* ret = graph_->AddNode(def, &status);
   TF_RETURN_IF_ERROR(status);
-  CondStateMap::CondState state = cond_state_map_.LookupState(replacee);
-  state.pop_back();
   VLOG(1) << "Adding If for " << replacee->name();
-  cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state));
+  StateMap::CondId id = state_map_.LookupCondId(replacee);
+  if (id) {
+    StateMap::CondState state = *id;
+    state.erase(predicate);
+    state_map_.ResetCondId(ret, state_map_.GetCondId(state));
+  } else {
+    state_map_.ResetCondId(ret, nullptr);
+  }
+
+  state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee));
+
   return ret;
 }
 
 Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) {
   VLOG(2) << "Propagating update state for " << replacee->name() << " "
-          << cond_state_map_.CondStateToString(replacee);
+          << state_map_.CondStateToString(replacee);
   // Redo topological sort as the order could have changed.
   // TODO(jpienaar): The original topological order could also be updated
   // dynamically if needed.
@@ -801,10 +855,10 @@
     if (changed.find(*it) != changed.end()) {
       // Update the node state.
       Node* n = *it;
-      CondStateMap::CondId old_state = cond_state_map_.LookupId(n);
-      cond_state_map_.ResetId(n, nullptr);
+      StateMap::CondId old_state = state_map_.LookupCondId(n);
+      state_map_.ResetCondId(n, nullptr);
       TF_RETURN_IF_ERROR(DetermineCondState(n));
-      if (cond_state_map_.LookupId(n) != old_state) {
+      if (state_map_.LookupCondId(n) != old_state) {
         for (auto out : n->out_nodes())
           if (out->IsOp()) changed.insert(out);
       }
@@ -825,127 +879,44 @@
   return BranchType::kNeither;
 }
 
-CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds(
-    CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
-  CondId lhs_scope;
-  CondId rhs_scope;
-  bool could_determine_scope = ScopeIn(lhs, &lhs_scope);
-  could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope);
-  if (!could_determine_scope) return kIncomparable;
-
-  // Returns whether a contains b.
-  auto contains = [&](CondId a, CondId b) {
-    // Handle empty states.
-    if (a == nullptr && b != nullptr) return true;
-    if (a == nullptr && b == nullptr) return true;
-    if (a != nullptr && b == nullptr) return false;
-
-    if (a->size() > b->size()) return false;
-    auto a_it = a->begin();
-    auto b_it = b->begin();
-    while (a_it != a->end()) {
-      if (*a_it != *b_it) {
-        if (!(a_it->predicate == b_it->predicate)) return false;
-        BranchType mb = MeetBranch(a_it->branch, b_it->branch);
-        if (mb != b_it->branch) return false;
-      }
-      ++a_it;
-      ++b_it;
-    }
-    return true;
-  };
-
-  bool lhs_contains_rhs = contains(lhs_scope, rhs_scope);
-  bool rhs_contains_lhs = contains(rhs_scope, lhs_scope);
-  if (lhs_contains_rhs && rhs_contains_lhs) return kEqual;
-  if (lhs_contains_rhs) return kLhsContainsRhs;
-  if (rhs_contains_lhs) return kRhsContainsLhs;
-  return kIncomparable;
-}
-
-BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
+BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const {
   if (IsEmpty(id)) return BranchType::kNeither;
-  absl::optional<BranchType> b;
   const CondState& nodes = *id;
-  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
-    if (it->type == CondStateMap::CondNode::Type::kSwitch &&
-        it->predicate == predicate) {
-      if (b.has_value()) {
-        b = MeetBranch(*b, it->branch);
-      } else {
-        b = it->branch;
-      }
-      if (*b == BranchType::kNeither) {
-        LOG(FATAL) << "Inconsistent state for node: " << DebugString(id);
-      }
-    }
-  }
-  return b.has_value() ? *b : BranchType::kNeither;
+  auto it = nodes.find(predicate);
+  if (it == nodes.end()) return BranchType::kNeither;
+  return it->second;
 }
 
-StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
-    CondStateMap::CondId src, CondStateMap::CondId dst) {
-  VLOG(4) << "Joining src=" << DebugString(src) << " [" << src
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesNonMerge(
+    StateMap::CondId src, StateMap::CondId dst) {
+  VLOG(5) << "Joining src=" << DebugString(src) << " [" << src
           << "] and dst=" << DebugString(dst) << " [" << dst << "]";
 
-  if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src;
-  if (cond_state_map_.IsDead(dst)) return dst;
+  if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src;
+  if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst;
 
   // Nothing to do if the CondState is the same.
   if (src == dst) return src;
 
-  CondStateMap::CondId src_scope;
-  CondStateMap::CondId dst_scope;
-  if (!cond_state_map_.ScopeIn(src, &src_scope))
-    return errors::Unimplemented(
-        "Predicates that must hold for node to execute are invalid! ",
-        DebugString(src));
-  if (!cond_state_map_.ScopeIn(dst, &dst_scope))
-    return errors::Unimplemented(
-        "Predicates that must hold for node to execute are invalid! ",
-        DebugString(dst));
-
-  auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope);
-  switch (result) {
-    case CondStateMap::kIncomparable:
-      return errors::InvalidArgument(
-          "Graph contains node with inputs predicated on incompatible "
-          "predicates: ",
-          DebugString(src), " and ", DebugString(dst));
-    case CondStateMap::kEqual:
-      // If both respect the same predicates, propagate the longer constraint.
-      if ((src != nullptr && dst == nullptr) ||
-          (src != nullptr && dst != nullptr && src->size() > dst->size()))
-        return src;
-      else
-        return dst;
-    case CondStateMap::kLhsContainsRhs:
-      // src contains dst, so dst is already more restrictive.
-      return dst;
-    case CondStateMap::kRhsContainsLhs:
-      // dst contains src, so src is more restrictive.
-      return src;
+  StateMap::CondState both = *src;
+  for (const auto& kv : *dst) {
+    auto it = both.find(kv.first);
+    if (it == both.end()) {
+      both.insert(kv);
+    } else {
+      if (it->second != kv.second) {
+        return errors::InvalidArgument(
+            "Graph contains node with inputs predicated on incompatible "
+            "predicates: ",
+            DebugString(src), " and ", DebugString(dst));
+      }
+    }
   }
+  return state_map_.GetCondId(both);
 }
 
-StatusOr<CondStateMap::CondState::const_iterator>
-FindThenElseSwitchForPredicate(const OutputTensor& pred,
-                               CondStateMap::CondId id) {
-  for (auto it = id->begin(); it != id->end(); ++it) {
-    // Along every path one there can be only one instance of a then or else
-    // switch for a given predicate, so return once found.
-    if (it->type == CondStateMap::CondNode::Type::kSwitch &&
-        it->predicate == pred &&
-        (it->branch == BranchType::kThenBranch ||
-         it->branch == BranchType::kElseBranch))
-      return it;
-  }
-  return errors::Internal("Unable to find then/else branch with predicate ",
-                          DebugString(pred), " for ", DebugString(id));
-}
-
-StatusOr<CondStateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
-    CondStateMap::CondId src, CondStateMap::CondId dst) {
+StatusOr<StateMap::CondId> FunctionalizeCond::JoinCondStatesMerge(
+    Node* merge, StateMap::CondId src, StateMap::CondId dst) {
   // Determine the flow state when joining two states for a merge
   // node. Combining the two states for a merge node is effectively performing a
   // disjunction of the states along the different input edges. For a merge that
@@ -956,91 +927,56 @@
   // followed by s(p, both).
   VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
           << DebugString(dst);
-  if (cond_state_map_.IsEmpty(dst)) return src;
+  if (state_map_.IsEmpty(dst)) return src;
 
-  if (cond_state_map_.IsDead(src)) return src;
-  if (cond_state_map_.IsDead(dst)) return dst;
+  if (state_map_.IsDead(src)) return src;
+  if (state_map_.IsDead(dst)) return dst;
 
-  CondStateMap::CondId src_scope;
-  CondStateMap::CondId dst_scope;
-  if (!cond_state_map_.ScopeIn(src, &src_scope))
-    return errors::Unimplemented(
-        "Predicates that must hold for node to execute are invalid! ",
-        DebugString(src));
-  if (!cond_state_map_.ScopeIn(dst, &dst_scope))
-    return errors::Unimplemented(
-        "Predicates that must hold for node to execute are invalid! ",
-        DebugString(dst));
+  std::vector<StateMap::CondState::value_type> diff;
+  StateMap::CondState merged;
+  std::set_symmetric_difference(src->begin(), src->end(), dst->begin(),
+                                dst->end(), std::back_inserter(diff),
+                                CondStateLess());
+  std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(),
+                        std::inserter(merged, merged.begin()), CondStateLess());
 
-  TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr)
-      << "Illegal merge inputs from outer scope: src=" << DebugString(src)
-      << " dst=" << DebugString(dst);
-  auto src_it = src_scope->begin();
-  auto dst_it = dst_scope->begin();
-
-  // Find branch divergent condition.
-  OutputTensor pred;
-  while (src_it != src_scope->end() && dst_it != dst_scope->end()) {
-    if (*src_it != *dst_it) {
-      VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and "
-              << DebugString(*dst_it);
-      if (!(src_it->predicate == dst_it->predicate)) {
-        return errors::InvalidArgument(
-            "Unable to find common predicate which holds for one input "
-            "but not the other of the merge node.");
-      }
-      pred = src_it->predicate;
-      break;
-    }
-    ++src_it;
-    ++dst_it;
-  }
-
-  if (pred.node == nullptr)
-    return errors::InvalidArgument("Unable to determine predicate for merge.");
-
-  TF_ASSIGN_OR_RETURN(auto div_src_it,
-                      FindThenElseSwitchForPredicate(pred, src));
-  TF_ASSIGN_OR_RETURN(auto div_dst_it,
-                      FindThenElseSwitchForPredicate(pred, dst));
-  TF_RET_CHECK(*div_src_it != *div_dst_it);
-
-  CondStateMap::CondState result;
-  // Populate result with the longest/most restrictive path up to the divergent
-  // node. For example, if the one input is `[switch(pred:0, then)]` and the
-  // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created
-  // in gradient of cond test), then the resultant state here should be
-  // `[switch(pred:0, both), merge, switch(pred:0, both)]`.
-  if (std::distance(src->begin(), div_src_it) >
-      std::distance(dst->begin(), div_dst_it)) {
-    result.assign(src->begin(), std::next(div_src_it));
+  // Update mapping from merge node to predicate.
+  if (diff.size() == 2) {
+    auto pred = diff[0].first;
+    bool different_branches = (diff[0].second != diff[1].second) &&
+                              (diff[0].second == BranchType::kThenBranch ||
+                               diff[0].second == BranchType::kElseBranch) &&
+                              (diff[1].second == BranchType::kThenBranch ||
+                               diff[1].second == BranchType::kElseBranch);
+    if (!(pred == diff[1].first) || !different_branches)
+      return errors::InvalidArgument(
+          "Unable to determine predicate for merge node");
+    merge_to_predicate_[merge] = pred;
   } else {
-    result.assign(dst->begin(), std::next(div_dst_it));
+    return errors::InvalidArgument(
+        "Merge of two inputs that differ on more than one predicate ",
+        DebugString(src), " and ", DebugString(dst));
   }
-  result.back().branch = BranchType::kBoth;
-  return cond_state_map_.GetUniqueId(result);
+
+  return state_map_.GetCondId(merged);
 }
 
-CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
+StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) {
   Node* src = e->src();
-  CondStateMap::CondId id = cond_state_map_.LookupId(e->src());
-  if (IsMerge(src)) {
-    CondStateMap::CondState state;
-    if (id != nullptr) state = *id;
-    state.emplace_back(CondStateMap::CondNode::Type::kMerge);
-    return cond_state_map_.GetUniqueId(state);
-  }
+  StateMap::CondId id = state_map_.LookupCondId(e->src());
+
+  // Dead nodes only propagate dead state.
+  if (state_map_.IsDead(id)) return id;
+
   if (IsSwitch(src)) {
-    CondStateMap::CondState state;
+    StateMap::CondState state;
     if (id != nullptr) state = *id;
-    if (e->IsControlEdge()) {
-      state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
-                         BranchType::kBoth);
-    } else {
-      state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src,
-                         BranchType(e->src_output()));
+    OutputTensor predicate;
+    TF_CHECK_OK(GetSwitchPredicate(*src, &predicate));
+    if (!e->IsControlEdge()) {
+      state[predicate] = BranchType(e->src_output());
     }
-    return cond_state_map_.GetUniqueId(state);
+    return state_map_.GetCondId(state);
   }
   return id;
 }
@@ -1049,22 +985,21 @@
   // Only Merge nodes with two inputs are supported, but if this is a redundant
   // merge, then the dead edge may already have been removed (if due to a
   // switch) and so the input count would be incorrect.
-  if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst)))
-    return Status::OK();
+  if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK();
 
   int data_inputs = 0;
   for (auto e : dst->in_edges()) {
     Node* src = e->src();
     VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " "
-            << cond_state_map_.CondStateToString(src);
+            << state_map_.CondStateToString(src);
     if (!src->IsOp()) continue;
     if (!e->IsControlEdge()) ++data_inputs;
 
-    CondStateMap::CondId prop = StateAlongEdge(e);
-    auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst));
+    StateMap::CondId prop = StateAlongEdge(e);
+    auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst));
     TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                     FormatNodeForError(*dst));
-    cond_state_map_.ResetId(dst, id_or.ValueOrDie());
+    state_map_.ResetCondId(dst, id_or.ValueOrDie());
   }
 
   // Incomplete Merge nodes are not supported.
@@ -1076,27 +1011,20 @@
   return Status::OK();
 }
 
-Status FunctionalizeCond::DetermineCondState(Node* dst) {
-  // The logic for the merge and non-merge case differ: for non-merge it is
-  // the most restrictive CondState, while for merge nodes the
-  // resultant state is less restrictive than either.
-  if (IsMerge(dst)) {
-    TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst));
-  } else {
-    // Handle non-merge join.
-    for (auto e : dst->in_edges()) {
-      VLOG(5) << "Processing forward flow for: " << e->DebugString() << " "
-              << cond_state_map_.CondStateToString(dst);
-      Node* src = e->src();
-      if (!src->IsOp()) continue;
+Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) {
+  // Handle non-merge join.
+  for (auto e : dst->in_edges()) {
+    VLOG(4) << "Processing forward flow for: " << e->DebugString() << " "
+            << state_map_.CondStateToString(dst);
+    Node* src = e->src();
+    if (!src->IsOp()) continue;
 
-      // Joining the state between the current and propagated state.
-      CondStateMap::CondId prop = StateAlongEdge(e);
-      auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst));
-      TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
-                                      FormatNodeForError(*dst));
-      cond_state_map_.ResetId(dst, id_or.ValueOrDie());
-    }
+    // Joining the state between the current and propagated state.
+    StateMap::CondId prop = StateAlongEdge(e);
+    auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst));
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
+                                    FormatNodeForError(*dst));
+    state_map_.ResetCondId(dst, id_or.ValueOrDie());
   }
   return Status::OK();
 }
@@ -1104,8 +1032,7 @@
 Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
   // Handle redundant merge nodes. A merge node is considered redundant if
   // one input edge is dead while the other has a value.
-  if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node)))
-    return Status::OK();
+  if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK();
 
   const Edge* non_dead_edge = nullptr;
   for (auto e : node->in_edges()) {
@@ -1113,8 +1040,8 @@
     Node* src = e->src();
 
     // Handle merge with dead state.
-    const auto& src_id = cond_state_map_.LookupId(src);
-    if (!cond_state_map_.IsDead(src_id)) {
+    const auto& src_id = state_map_.LookupCondId(src);
+    if (!state_map_.IsDead(src_id)) {
       non_dead_edge = e;
       break;
     }
@@ -1124,8 +1051,7 @@
     return errors::InvalidArgument("Merge node ", FormatNodeForError(*node),
                                    " has no non-dead inputs.");
   }
-  cond_state_map_.MarkDead(node);
-  delete_nodes_.push_back(node->id());
+  state_map_.MarkDead(node);
   VLOG(5) << "removing redundant merge: " << node->name();
   while (!node->out_edges().empty()) {
     const Edge* oe = *node->out_edges().begin();
@@ -1149,16 +1075,33 @@
   // along one. The checking of predicate is based on the exact predicate
   // (rather than boolean equivalence) and aimed at redundant switches as
   // currently generated by gradient code.
+  StateMap::CondId dst_id = state_map_.LookupCondId(node);
+  if (state_map_.IsDead(dst_id)) return Status::OK();
+
+  BranchType b;
   OutputTensor pred;
   TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred));
-  auto dst_id = cond_state_map_.LookupId(node);
-  BranchType b = cond_state_map_.FindBranchOf(dst_id, pred);
-  // Determine if we are already on a branch where the switch predicate is
-  // true/false.
-  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
-    return Status::OK();
 
-  VLOG(5) << "Redundant switch " << node->name();
+  // Determine if we are already on a branch where the switch predicate is
+  // true/false. Consider both the data and predicate to determine if the
+  // node is redundant (skipping over identity node).
+  b = state_map_.FindBranchOf(dst_id, pred);
+  if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) {
+    OutputTensor val;
+    const Edge* e;
+    TF_RETURN_IF_ERROR(node->input_edge(0, &e));
+    val = OutputTensor(e->src(), e->src_output());
+    while (IsIdentity(val.node)) {
+      TF_RETURN_IF_ERROR(val.node->input_edge(0, &e));
+      val = OutputTensor(e->src(), e->src_output());
+    }
+    b = state_map_.FindBranchOf(dst_id, val);
+    if (b != BranchType::kThenBranch && b != BranchType::kElseBranch)
+      return Status::OK();
+  }
+
+  VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " "
+          << DebugString(dst_id);
   const Edge* value_edge;
   TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge));
   Node* val_node = value_edge->src();
@@ -1171,20 +1114,19 @@
     graph_->RemoveEdge(e);
     if (switch_branch == Graph::kControlSlot) {
       if (IsMerge(dst_node)) {
-        auto id_or =
-            JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node));
+        auto id_or = JoinCondStatesMerge(dst_node, dst_id,
+                                         state_map_.LookupCondId(dst_node));
         TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ",
                                         FormatNodeForError(*dst_node));
-        cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
       } else {
         auto id_or =
-            JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node));
+            JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node));
         TF_RETURN_IF_ERROR(id_or.status());
-        cond_state_map_.ResetId(dst_node, id_or.ValueOrDie());
+        state_map_.ResetCondId(dst_node, id_or.ValueOrDie());
       }
     } else if (BranchType(switch_branch) != b) {
-      cond_state_map_.MarkDead(dst_node);
-      delete_nodes_.push_back(dst_node->id());
+      state_map_.MarkDead(dst_node);
       continue;
     }
     graph_->AddEdge(
@@ -1195,37 +1137,103 @@
   return Status::OK();
 }
 
-Status FunctionalizeCond::DetermineCondStates(
-    std::vector<Node*> rev_topo_order) {
+Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
   // The state that is propagated along the given edge.
   for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) {
     Node* dst = *it;
     TF_RETURN_IF_ERROR(DetermineCondState(dst));
+    TF_RETURN_IF_ERROR(DetermineAncestorState(dst));
     if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst));
     if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst));
 
-    VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst);
+    VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
+            << " @ " << state_map_.AncestorStateToString(dst);
+    if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
   }
   return Status::OK();
 }
 
-void FunctionalizeCond::DeleteReachableNodes() {
+Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
+  StateMap::AncestorId id = nullptr;
+  StateMap::AncestorState state;
+
+  auto insert = [&](StateMap::AncestorId id, Node* src) {
+    auto other_id = state_map_.LookupAncestorId(src);
+    if (other_id != id && other_id != nullptr) {
+      state.insert(other_id->begin(), other_id->end());
+    }
+    if (IsSwitch(src) || IsMerge(src)) {
+      state.insert(src);
+    }
+    return state_map_.GetAncestorId(state);
+  };
+
+  // Compute the union of all the switch/merge nodes that affect the input of
+  // dst.
+  for (auto e : dst->in_edges()) {
+    Node* src = e->src();
+    id = insert(id, src);
+  }
+  state_map_.ResetAncestorId(dst, id);
+  return Status::OK();
+}
+
+void FunctionalizeCond::DeleteReachableAndDeadNodes(
+    const std::vector<int>& switch_ids, const std::vector<Node*>& merge_order) {
   // Delete all nodes that have been extracted or are reachable from
   // deleted/dead nodes. The input and outgoing edges should have already been
   // removed.
+  std::deque<int> delete_nodes;
   std::vector<bool> deleted(graph_->num_node_ids(), false);
   // Don't try to delete source or sink nodes.
   deleted[graph_->kSourceId] = true;
   deleted[graph_->kSinkId] = true;
-  while (!delete_nodes_.empty()) {
-    int d_id = delete_nodes_.front();
-    delete_nodes_.pop_front();
+
+  // All remaining Switch nodes are not reachable from a Merge node and are
+  // removed here. This is to account for dead Switch nodes.
+  for (int s_id : switch_ids) {
+    Node* s = graph_->FindNodeId(s_id);
+    if (s == nullptr) continue;
+    for (const Edge* e : s->out_edges()) {
+      // Control outputs of switch nodes (which are unconditionally executed if
+      // the switch is) are not removed as they need not be part of a
+      // conditional.
+      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+    }
+    deleted[s_id] = true;
+    graph_->RemoveNode(s);
+  }
+
+  // All merge nodes should have been transformed at this point and we remove
+  // them from the graph here.
+  for (Node* m : merge_order) {
+    for (const Edge* e : m->out_edges()) {
+      // As with control outputs of switch nodes, don't remove control
+      // outputs of merge nodes.
+      // TODO(jpienaar): Check cases where output edges still exist here vs
+      // being removed in AddOutputEdges.
+      if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+    }
+    deleted[m->id()] = true;
+    graph_->RemoveNode(m);
+  }
+
+  // Enqueue all the dead nodes.
+  for (Node* n : graph_->nodes()) {
+    if (state_map_.IsDead(state_map_.LookupCondId(n))) {
+      delete_nodes.push_back(n->id());
+    }
+  }
+
+  while (!delete_nodes.empty()) {
+    int d_id = delete_nodes.front();
+    delete_nodes.pop_front();
     if (deleted[d_id]) continue;
     Node* d = graph_->FindNodeId(d_id);
     // Switch and Merge nodes could have been deleted already.
     if (d == nullptr) continue;
     for (const Edge* e : d->out_edges()) {
-      delete_nodes_.push_back(e->dst()->id());
+      delete_nodes.push_back(e->dst()->id());
     }
     deleted[d_id] = true;
     graph_->RemoveNode(d);
@@ -1239,16 +1247,8 @@
   inner_to_outer_merge_order.reserve(merge_order->size());
   for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) {
     Node* merge = *it;
-    CondStateMap::CondId id = cond_state_map_.LookupId(merge);
-    int depth = 0;
-    for (auto cond_node_it = id->begin(); cond_node_it != id->end();
-         ++cond_node_it) {
-      if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch &&
-          (cond_node_it->branch == BranchType::kThenBranch ||
-           cond_node_it->branch == BranchType::kElseBranch)) {
-        ++depth;
-      }
-    }
+    StateMap::CondId id = state_map_.LookupCondId(merge);
+    int depth = id != nullptr ? id->size() : 0;
     inner_to_outer_merge_order.emplace_back(depth, merge);
   }
   std::stable_sort(
@@ -1271,10 +1271,10 @@
   // determine deeper equivalence). We shall refer to this structure as the
   // CondState;
   // 3. Sort the merge nodes by nesting depth;
-  // 4. Extract merge nodes together that have the same CondState and whose
-  // input nodes have the same state from the innermost to the outermost into
-  // IfOps; Note: In the above only nodes paths that converge to a merge node
-  // will be considered for removal.
+  // 4. Extract merge nodes together that have the same CondState and
+  // AncestorState from the innermost to the outermost into IfOps;
+  // Note: In the above only nodes that feed into a merge node will be
+  // considered for functionalization.
 
   // Perform a DFS over the graph and
   // * Determine the reverse topological order of the nodes (there should be no
@@ -1306,50 +1306,46 @@
     return Status::OK();
   }
 
-  TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order)));
-
-  if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+  TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
+  if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");
 
   // Sort the merge nodes from innermost outwards.
   SortMergeNodes(&merge_order);
 
-  // Extract from innermost out.
-  for (auto it = merge_order.begin(); it != merge_order.end(); ++it) {
-    Node* merge = *it;
-    auto id = cond_state_map_.LookupId(merge);
-    if (cond_state_map_.IsDead(id)) continue;
+  // Cluster merge nodes by CondId and AncestorId in order of nesting.
+  using ClusterPair = std::pair<StateMap::CondId, StateMap::AncestorId>;
+  std::deque<std::vector<Node*>> merge_clusters;
+  std::map<ClusterPair, int> merge_cluster_index;
+  for (Node* merge : merge_order) {
+    auto cond_id = state_map_.LookupCondId(merge);
+    if (state_map_.IsDead(cond_id)) continue;
 
-    // Construct a Conditional with the predicate of the merge (which is the
-    // last entry of the CondState for the merge) and this as parent.
-    DCHECK(id->back().predicate.node != nullptr);
-    Conditional cond(id->back().predicate, this, &cond_state_map_);
-    TF_RETURN_IF_ERROR(cond.AddMerge(merge));
-
-    // Find all merge nodes with the same CondId. This is done repeatedly as
-    // the CondId can change due replaced conditionals. E.g., the one branch
-    // could previously have had a conditional nested in it, and so would have
-    // had CondState with sub-state [switch(p,b),m] (where p is some predicate),
-    // post removing the nested conditional that sub-state would no longer be
-    // path of the propagated state along that path.
-    auto end = merge_order.end();
-    for (auto merge_candidate_it = std::next(it); merge_candidate_it != end;
-         ++merge_candidate_it) {
-      auto merge_candidate_it_id =
-          cond_state_map_.LookupId(*merge_candidate_it);
-      if (merge_candidate_it_id != id) continue;
-      TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it));
+    ClusterPair key =
+        std::make_pair(cond_id, state_map_.LookupAncestorId(merge));
+    auto idx = merge_cluster_index.find(key);
+    if (idx == merge_cluster_index.end()) {
+      merge_cluster_index[key] = merge_clusters.size();
+      merge_clusters.push_back({merge});
+    } else {
+      merge_clusters[idx->second].emplace_back(merge);
     }
+  }
 
+  // Extract the conditionals from inner most to outer most. Extracting from
+  // innermost to outermost enables the extraction pass to stop once it
+  // encounters a Switch node instead of having to keep track of Switch/Merge
+  // nodes seen.
+  for (const auto& cluster : merge_clusters) {
+    // Construct a Conditional with the predicate of the merge.
+    Conditional cond(merge_to_predicate_.at(cluster.front()), this,
+                     &state_map_);
+    for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge));
     TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_));
 
     if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
   }
 
-  // All remaining Switch nodes are not reachable from a Merge node and
-  // removed. This is to account for dead Switch nodes.
-  for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
-  for (Node* m : merge_order) delete_nodes_.push_back(m->id());
-  DeleteReachableNodes();
+  DeleteReachableAndDeadNodes(switch_ids, merge_order);
 
   return Status::OK();
 }
@@ -1359,11 +1355,14 @@
 
   for (Node* n : graph_->nodes()) {
     n->ClearAttr(kCondGroupDebugAttr);
-    n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n));
+    n->AddAttr(kCondGroupDebugAttr,
+               absl::StrCat(state_map_.CondStateToString(n), "_",
+                            state_map_.AncestorStateToString(n)));
   }
   LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
             << dump_graph::DumpGraphToFile(
-                   strings::StrCat("functionalize_", name), *graph_, library_);
+                   absl::StrCat("functionalize_cond_", name), *graph_,
+                   library_);
 }
 
 Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 8643601..1899808 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -43,59 +43,53 @@
   kNeither = 3,
 };
 
-// CondStateMap is responsible for mapping from each graph Node to a CondState,
-// where each CondState is the array of CondNodes (corresponding to switch,
-// merge or dead states) as described below.  For efficiency, this class interns
-// the CondState, so that CondState equality comparisons are simply pointer
+// StateMap is responsible for mapping from each graph Node to
+// * a CondState, where each CondState is a map from predicate to branch (i.e.,
+//   what predicates have to hold or not hold).
+// * an AncestorState, where each AncestorState is a set of switch/merge nodes
+//   that are ancestors of the node in the graph;
+// For efficiency, this class interns the CondState (AncestorState), so that
+// CondState (AncestorState) equality comparisons are simply pointer
 // comparisons.
-class CondStateMap {
+class StateMap {
  public:
-  explicit CondStateMap(Graph* graph);
+  explicit StateMap(Graph* graph);
 
-  // Represents an entry in the CondState. An entry can either be the
-  // switch (along with predicate), merge, or dead:
-  // * switch node indicates a node that is executed along a branch with the
-  //   given predicate - a branch can be then, else or both;
-  // * merge node indicates that the node is executed as output of a merge;
-  // * dead indicates that this node can never be executed;
-  struct CondNode {
-    enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 };
-
-    CondNode(Type type, Node* switch_node = nullptr,
-             BranchType branch = BranchType::kNeither);
-
-    string ToString() const;
-    bool operator==(const CondNode& other) const;
-    bool operator!=(const CondNode& other) const;
-
-    // Type of node.
-    Type type;
-
-    // Predicate and branch, only used when type is kSwitch.
-    OutputTensor predicate;
-    BranchType branch;
+  // Compare two OutputTensors by (node id, index).
+  struct OutputTensorLess {
+    bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const;
   };
 
-  // A node in the graph is executed when multiple conditions hold. The order
-  // represents the nesting of the predicates that hold and is used when
-  // extracting the nested conditionals.
-  using CondState = std::vector<CondNode>;
+  // A node in the graph is executed when multiple conditions hold. Keep track
+  // of the predicates that must hold for a node to execute.
+  using CondState = std::map<OutputTensor, BranchType, OutputTensorLess>;
 
   // Every unique ID is mapped to a CondState.
   using CondId = const CondState*;
 
+  // Keep track of which switch/merge nodes feed into a node's values.
+  using AncestorState = std::set<Node*>;
+
+  // Every unique ID is mapped to an AncestorState.
+  using AncestorId = const AncestorState*;
+
   // Returns the CondId for a given node.
-  CondId LookupId(const Node* node) const;
+  CondId LookupCondId(const Node* node) const;
 
   // Returns the unique CondId for CondState.
-  CondId GetUniqueId(const CondState& state);
-
-  // Returns the CondState for a Node.
-  // REQUIRES: node has a non-empty CondState.
-  const CondState& LookupState(const Node* node) const;
+  CondId GetCondId(const CondState& state);
 
   // Resets the CondId for a given node.
-  void ResetId(const Node* node, CondId id);
+  void ResetCondId(const Node* node, CondId id);
+
+  // Returns the AncestorId for a given node.
+  AncestorId LookupAncestorId(const Node* node) const;
+
+  // Returns the unique AncestorId for AncestorState.
+  AncestorId GetAncestorId(const AncestorState& state);
+
+  // Resets the AncestorId for a given node.
+  void ResetAncestorId(const Node* node, AncestorId id);
 
   // Marks `node` as dead.
   void MarkDead(const Node* node);
@@ -103,45 +97,30 @@
   // Determine branch execution of CondState.
   BranchType FindBranchOf(CondId id, OutputTensor predicate) const;
 
-  // Enum to represent whether one cond flow state contains another.
-  enum ContainsResult {
-    kIncomparable,
-    kEqual,
-    kLhsContainsRhs,
-    kRhsContainsLhs
-  };
-
-  // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e.,
-  // [(p,t)] contains [(p,t), (r,t)].
-  ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs);
-
   // Returns textual representation of node's CondState.
   string CondStateToString(const Node* node) const;
   string CondStateToString(CondId id) const;
 
+  // Returns textual representation of node's AncestorState.
+  string AncestorStateToString(const Node* node) const;
+
   // Returns whether the cond state is the dead state.
   bool IsDead(CondId id) const;
 
   // Returns whether the cond state is the empty state.
   bool IsEmpty(CondId id) const;
 
-  // Computes the predicates that have to hold for a node to execute and returns
-  // whether it was possible to determine the predicates that must hold. `scope`
-  // is populated with these predicates. Scope differs from state in that it
-  // does not include merge and both nodes.
-  bool ScopeIn(CondId id, CondId* scope);
-
  private:
-  // Hash for CondNode and CondState.
-  struct CondHash {
-    size_t operator()(const CondNode& item) const;
-    size_t operator()(const CondState& vec) const;
+  // Hash for CondState and AncestorState.
+  struct Hash {
+    size_t operator()(const CondState& map) const;
+    size_t operator()(const AncestorState& map) const;
   };
 
   // Set to keep track of unique CondStates.
   // Pointers to the entries in the unordered set are used as identifiers:
   // unordered_set guarantees that the pointers remain the same.
-  std::unordered_set<CondState, CondHash> condstate_set_;
+  std::unordered_set<CondState, Hash> condstate_set_;
 
   // Mapping from Node id to CondId.
   std::vector<CondId> node_to_condid_map_;
@@ -150,7 +129,12 @@
   // from Node id in the original graph to the CondId, but there will be nodes
   // added to the original graph (such as If nodes) whose CondState needs to be
   // tracked too.
-  std::unordered_map<int, CondId> added_node_mapping_;
+  std::unordered_map<int, CondId> added_node_condid_mapping_;
+
+  // AncestorId variants of the CondId members.
+  std::unordered_set<AncestorState, Hash> ancestorstate_set_;
+  std::vector<AncestorId> node_to_ancestorid_map_;
+  std::unordered_map<int, AncestorId> added_node_ancestorid_mapping_;
 
   // Identifier of the dead flow state. The empty flow state is represented with
   // a nullptr.
@@ -173,7 +157,8 @@
 
   // Add a If node to the graph defined by def that will, amongst other, replace
   // replacee in the graph.
-  xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee);
+  xla::StatusOr<Node*> AddIfNode(const NodeDef& def, const Node* replacee,
+                                 const OutputTensor& predicate);
 
   // Propagates the state of a newly inserted node.
   Status PropagateUpdatedState(const Node* replacee);
@@ -185,35 +170,42 @@
   FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library);
 
   // Performs the actual cond functionalization. Iterate over groups of merge
-  // nodes (linked by common predicate & CondIds of the incomming edges),
-  // from innermost to outermost, and extract into If nodes.
+  // nodes (linked by common predicates & ancestor IDs), from innermost to
+  // outermost, and extract into If nodes.
   Status FunctionalizeInternal();
 
   // Returns the forward flow state propagated along edge `e`.
-  // This may modify cond_state_map_.
-  CondStateMap::CondId StateAlongEdge(const Edge* e);
+  // This may modify state_map_.
+  StateMap::CondId StateAlongEdge(const Edge* e);
 
-  // Determines the CondState of all the nodes in the given vector where
-  // the input is expected in reverse topological order.
-  // This populates the cond_state_map_.
-  Status DetermineCondStates(std::vector<Node*> rev_topo_order);
+  // Determines the CondState and AncestorState of all the nodes in the given
+  // vector where the input is expected in reverse topological order.
+  // This populates the state_map_.
+  Status DetermineStates(std::vector<Node*> rev_topo_order);
 
   // Determine the CondState for a given node using the incomming edges
   // to the node. Note: it is expected that this node's CondState is only
   // determined once its input's CondState is.
-  Status DetermineCondState(Node* dst);
+  Status DetermineCondState(Node* dst) {
+    if (IsMerge(dst)) return DetermineCondStateMerge(dst);
+    return DetermineCondStateNonMerge(dst);
+  }
 
   // Helper functions for DetermineCondState.
+  Status DetermineCondStateNonMerge(Node* dst);
   Status DetermineCondStateMerge(Node* dst);
 
-  // Helper functions for DetermineCondStates. Determines the dst node's
-  // CondState by joining the src and dst's CondState where either
-  // the dst node is a merge or not.
-  // These may modify cond_state_map_.
-  xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
-      CondStateMap::CondId src, CondStateMap::CondId dst);
-  xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
-      CondStateMap::CondId src, CondStateMap::CondId dst);
+  // Determines the dst node's CondState by joining the src and dst's CondState
+  // where either the dst node is a merge or not.
+  // These may modify state_map_.
+  xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* merge,
+                                                      StateMap::CondId src,
+                                                      StateMap::CondId dst);
+  xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+                                                         StateMap::CondId dst);
+
+  // Determines which switch/merge nodes are ancestors of this node.
+  Status DetermineAncestorState(Node* dst);
 
   // Checks if a merge node is redundant and if so removes it from the graph.
   Status RemoveRedundantMerge(Node* node);
@@ -225,15 +217,18 @@
   // nesting depth.
   void SortMergeNodes(std::vector<Node*>* merge_order);
 
-  // Deletes all nodes in/consumers of `delete_nodes_`.
-  void DeleteReachableNodes();
+  // Deletes all nodes reachable from (i.e., consumers of) the switch/merge
+  // nodes that were extracted, as well as dead nodes.
+  void DeleteReachableAndDeadNodes(const std::vector<int>& switch_ids,
+                                   const std::vector<Node*>& merge_order);
 
-  // Member used to unique the CondState to a unique CondId and keep track of
-  // CondState/CondId per Node.
-  CondStateMap cond_state_map_;
+  // Member used to unique the CondState to a unique CondId (AncestorState to a
+  // unique AncestorId) and keep track of CondState/CondId
+  // (AncestorState/AncestorId) per Node.
+  StateMap state_map_;
 
-  // Nodes to be deleted.
-  std::deque<int> delete_nodes_;
+  // Mapping from merge nodes to predicate.
+  std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
 
   FunctionLibraryDefinition* library_;
   Graph* graph_;
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
index a27f889..b0aabd6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
@@ -37,28 +37,23 @@
                                                         flib_def_.get()));
   }
 
-  CondStateMap::CondId GetUniqueId(
-      const CondStateMap::CondStateMap::CondState& state) {
-    return fc_->cond_state_map_.GetUniqueId(state);
+  StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) {
+    return fc_->state_map_.GetCondId(state);
   }
 
-  xla::StatusOr<CondStateMap::CondId> JoinCondStatesNonMerge(
-      CondStateMap::CondId src, CondStateMap::CondId dst) {
+  string GetString(const StateMap::StateMap::CondId id) {
+    return fc_->state_map_.CondStateToString(id);
+  }
+
+  xla::StatusOr<StateMap::CondId> JoinCondStatesNonMerge(StateMap::CondId src,
+                                                         StateMap::CondId dst) {
     return fc_->JoinCondStatesNonMerge(src, dst);
   }
 
-  xla::StatusOr<CondStateMap::CondId> JoinCondStatesMerge(
-      CondStateMap::CondId src, CondStateMap::CondId dst) {
-    return fc_->JoinCondStatesMerge(src, dst);
-  }
-
-  bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) {
-    return fc_->cond_state_map_.ScopeIn(ff, scope);
-  }
-
-  CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds(
-      CondStateMap::CondId lhs, CondStateMap::CondId rhs) {
-    return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs);
+  xla::StatusOr<StateMap::CondId> JoinCondStatesMerge(Node* n,
+                                                      StateMap::CondId src,
+                                                      StateMap::CondId dst) {
+    return fc_->JoinCondStatesMerge(n, src, dst);
   }
 
   FunctionDefLibrary fdef_lib_;
@@ -69,50 +64,6 @@
 
 namespace {
 
-TEST_F(FunctionalizeCondTest, ScopeIn) {
-  Tensor pred_tensor(DT_BOOL, TensorShape());
-  pred_tensor.flat<bool>().setZero();
-  Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred");
-  Tensor val_tensor(DT_INT32, TensorShape());
-  val_tensor.flat<int>().setZero();
-  Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
-  Node* s = test::graph::Switch(graph_.get(), val, pred);
-
-  {
-    CondStateMap::CondStateMap::CondState ss;
-    ss.emplace_back(CondStateMap::CondNode(
-        CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
-    CondStateMap::CondId id = GetUniqueId(ss);
-    CondStateMap::CondId scope;
-    ASSERT_TRUE(ScopeIn(id, &scope));
-    ASSERT_TRUE(id == scope);
-  }
-
-  CondStateMap::CondState empty;
-  {
-    CondStateMap::CondState ss;
-    ss.emplace_back(CondStateMap::CondNode(
-        CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
-    ss.emplace_back(
-        CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
-    CondStateMap::CondId id = GetUniqueId(ss);
-    CondStateMap::CondId scope_1;
-    ASSERT_TRUE(ScopeIn(id, &scope_1));
-    ASSERT_TRUE(scope_1 == GetUniqueId(empty));
-    ASSERT_TRUE(id != scope_1);
-
-    ss.clear();
-    ss.emplace_back(CondStateMap::CondNode(
-        CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth));
-    id = GetUniqueId(ss);
-    CondStateMap::CondId scope_2;
-    ASSERT_TRUE(ScopeIn(id, &scope_2));
-
-    ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) ==
-                CondStateMap::ContainsResult::kLhsContainsRhs);
-  }
-}
-
 TEST_F(FunctionalizeCondTest, JoinCondStates) {
   Tensor pred_tensor(DT_BOOL, TensorShape());
   pred_tensor.flat<bool>().setZero();
@@ -120,22 +71,18 @@
   Tensor val_tensor(DT_INT32, TensorShape());
   val_tensor.flat<int>().setZero();
   Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
-  Node* s = test::graph::Switch(graph_.get(), val, pred);
+  Node* m = test::graph::Merge(graph_.get(), val, val);
 
-  CondStateMap::CondId empty = GetUniqueId({});
-
-  CondStateMap::CondId then_branch;
+  StateMap::CondId then_branch;
   {
-    CondStateMap::CondState ss;
-    ss.emplace_back(CondStateMap::CondNode(
-        CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch));
+    StateMap::CondState ss;
+    ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch));
     then_branch = GetUniqueId(ss);
   }
-  CondStateMap::CondId else_branch;
+  StateMap::CondId else_branch;
   {
-    CondStateMap::CondState ss;
-    ss.emplace_back(CondStateMap::CondNode(
-        CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch));
+    StateMap::CondState ss;
+    ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch));
     else_branch = GetUniqueId(ss);
   }
 
@@ -144,39 +91,14 @@
   EXPECT_TRUE(errors::IsInvalidArgument(status));
 
   // Merge between then and else branch.
-  auto joined_or = JoinCondStatesMerge(then_branch, else_branch);
+  auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch);
   TF_EXPECT_OK(joined_or.status());
-  CondStateMap::CondId joined = joined_or.ValueOrDie();
+  StateMap::CondId joined = joined_or.ValueOrDie();
 
   // Merge between then branch and both branch.
   auto t = JoinCondStatesNonMerge(then_branch, joined);
   // Note: this is OK in terms of constraint predication, but
   TF_EXPECT_OK(t.status());
-
-  // Post merge the propagated forward flow state has an additional merge.
-  CondStateMap::CondId post_merge;
-  {
-    CondStateMap::CondState ss;
-    ss = *joined;
-    ss.emplace_back(
-        CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge));
-    post_merge = GetUniqueId(ss);
-  }
-
-  t = JoinCondStatesNonMerge(post_merge, joined);
-  TF_EXPECT_OK(t.status());
-  EXPECT_TRUE(joined == t.ValueOrDie());
-
-  // No predicate that results in two paths predicated on different conditions
-  // merge.
-  t = JoinCondStatesMerge(post_merge, joined);
-  EXPECT_FALSE(t.ok());
-
-  // Post the merge we are effectively in the root scope and merging should
-  // result in the more restrictive post merge state.
-  t = JoinCondStatesNonMerge(post_merge, empty);
-  TF_EXPECT_OK(t.status());
-  EXPECT_TRUE(post_merge == t.ValueOrDie());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 5932be4..f792c52 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,16 @@
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 
@@ -68,4 +73,146 @@
   return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
 }
 
+Status FunctionalizeControlFlowForFunction(
+    const string& func_name, const string& new_func_name,
+    const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+    FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+    std::map<string, string>* canonicalized_name_to_new_name) {
+  // Convert the function to Graph.
+  FunctionLibraryRuntime::Handle handle;
+  TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
+  Status ret_status = Status::OK();
+  auto cleanup_handle = gtl::MakeCleanup([&]() {
+    auto s = flr->ReleaseHandle(handle);
+    if (!s.ok()) {
+      ret_status.Update(s);
+    }
+  });
+  const FunctionBody* body = flr->GetFunctionBody(handle);
+  const FunctionDef& fdef = body->fdef;
+
+  // If any node has associated functions, functionalize them first.
+  // Gather nodes with associated functions first, because rewriting those nodes
+  // might involve node deletion/addition. Avoid modifying nodes while
+  // iterating over them.
+  std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
+      nodes_to_associated_functions;
+  for (auto* n : body->graph->nodes()) {
+    auto associated_functions = GetAssociatedFunctions(*n, flr);
+    if (!associated_functions.empty()) {
+      nodes_to_associated_functions.push_back({n, associated_functions});
+    }
+  }
+  for (auto iter : nodes_to_associated_functions) {
+    Node* n = iter.first;
+    auto associated_functions = iter.second;
+    for (auto& associated_function : associated_functions) {
+      string name = associated_function.func_name();
+      string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+      auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
+      string new_name;
+      if (iter != canonicalized_name_to_new_name->end()) {
+        // If we already functionalized this function, skip functionalization
+        // but still rewrite the node.
+        new_name = iter->second;
+      } else {
+        new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+        TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+            name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+        (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+      }
+      // Notice that if "n" is a function call, RewriteAssociatedFunction() will
+      // delete it and create a new node instead, making "n" an invalid pointer.
+      // That's fine because in that case, associated_functions will only have
+      // one member and the loop will only run once.
+      TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+          body->graph, n, fld, associated_function, new_name));
+    }
+  }
+
+  // Functionalize the function body.
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile(
+        absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+        *body->graph, fld);
+  }
+  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile(
+        absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+        *body->graph, fld);
+  }
+  FunctionDef functionalized_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
+
+  // Copy signature and ret from original FunctionDef.
+  *functionalized_fdef.mutable_signature() = fdef.signature();
+  *functionalized_fdef.mutable_ret() = fdef.ret();
+  functionalized_fdef.mutable_signature()->set_name(new_func_name);
+
+  // Add rewritten FunctionDef into library.
+  if (func_name == new_func_name) {
+    VLOG(2) << "Replacing function " << func_name;
+    TF_RETURN_IF_ERROR(
+        fld->ReplaceFunction(new_func_name, functionalized_fdef));
+  } else {
+    VLOG(2) << "Adding function " << new_func_name;
+    TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+  }
+
+  return ret_status;
+}
+
+Status FunctionalizeControlFlowPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  Graph* graph = options.graph->get();
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph,
+                                options.flib_def);
+  }
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+      new ProcessFunctionLibraryRuntime(
+          /*device_mgr=*/nullptr, options.session_options->env,
+          TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
+  FunctionLibraryRuntime* flr =
+      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+  // Find XLA compile ops and its corresponding FunctionDef.
+  static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
+      new std::map<string, string>{
+          {"TPUCompile", "function"},
+          {"XlaLaunch", "function"},
+      };
+  std::map<string, string> canonicalized_name_to_new_name;
+  for (Node* n : graph->nodes()) {
+    auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
+    if (it == kNodeTypeToFunctionAttrMapping->end()) {
+      continue;
+    }
+    const string func_attr = it->second;
+    if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
+        kNodeTypeToFunctionAttrMapping->end()) {
+      NameAttrList func;
+      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
+      VLOG(2) << "Graph has node " << n->type_string()
+              << ". Corresponding function: " << func.name();
+      string new_func_name = options.flib_def->UniqueFunctionName(
+          absl::StrCat(func.name(), "_f15n_"));
+      TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+          func.name(), new_func_name, func.attr(), options.flib_def, flr,
+          &canonicalized_name_to_new_name));
+      n->ClearAttr(func_attr);
+      func.set_name(new_func_name);
+      n->AddAttr(func_attr, func);
+    }
+  }
+
+  if (VLOG_IS_ON(4)) {
+    dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
+                                options.flib_def);
+  }
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index 55600f2..ba99205 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
 
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/graph/graph.h"
 
@@ -32,6 +33,14 @@
                                 Graph* graph,
                                 FunctionLibraryDefinition* library);
 
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (If/While).
+class FunctionalizeControlFlowPass : public GraphOptimizationPass {
+ public:
+  Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
new file mode 100644
index 0000000..a10a9d0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
+
+namespace tensorflow {
+
+// This pass is required by some AOT backends and all JIT backends, so this
+// file is built as a separate library and linked into both AOT and JIT.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27,
+                      FunctionalizeControlFlowPass);
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index c068a41..c3841f9 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
@@ -112,16 +113,12 @@
     auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
     auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
     auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
-    auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
-                            std::initializer_list<Input>{less, y, x}, then_fn,
-                            else_fn, {DT_INT32});
+    auto if_op = ops::If(scope.WithOpName(op_name), less,
+                         std::initializer_list<Input>{less, y, x}, {DT_INT32},
+                         then_fn, else_fn);
     auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
     GraphDef expected;
     TF_EXPECT_OK(scope.ToGraphDef(&expected));
-    // TODO(jpienaar): Create wrapper for IfOp.
-    for (NodeDef& n : *expected.mutable_node()) {
-      if (n.op() == "XlaIf") n.set_op("If");
-    }
     TF_EXPECT_GRAPH_EQ(expected, graph_def);
   }
 
@@ -177,7 +174,7 @@
 Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
                             NameAttrList* body) {
   for (const NodeDef& node : graph.node()) {
-    if (node.op() == "XlaWhile") {
+    if (node.op() == "While") {
       const NameAttrList* result;
       TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
       *cond = *result;
@@ -186,7 +183,7 @@
       return Status::OK();
     }
   }
-  return errors::NotFound("No XlaWhile node found in graph");
+  return errors::NotFound("No While node found in graph");
 }
 
 // Graph:
@@ -255,8 +252,8 @@
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
     auto while_op =
-        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
-                      std::initializer_list<Input>{source}, cond_fn, body_fn);
+        ops::While(scope.WithOpName("while/LoopCond"),
+                   std::initializer_list<Input>{source}, cond_fn, body_fn);
     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
     GraphDef expected;
     TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -392,8 +389,8 @@
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
     auto while_op =
-        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
-                      std::initializer_list<Input>{source}, cond_fn, body_fn);
+        ops::While(scope.WithOpName("while/LoopCond"),
+                   std::initializer_list<Input>{source}, cond_fn, body_fn);
     GraphDef expected;
     TF_ASSERT_OK(scope.ToGraphDef(&expected));
     TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -483,8 +480,8 @@
     Scope scope = Scope::NewRootScope().ExitOnError();
     auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
     auto while_op =
-        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
-                      std::initializer_list<Input>{source}, cond_fn, body_fn);
+        ops::While(scope.WithOpName("while/LoopCond"),
+                   std::initializer_list<Input>{source}, cond_fn, body_fn);
     GraphDef expected;
     TF_EXPECT_OK(scope.ToGraphDef(&expected));
     TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -625,8 +622,8 @@
     auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
     auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
     auto while_op =
-        ops::XlaWhile(scope.WithOpName("while/LoopCond"),
-                      std::initializer_list<Input>{x, y}, cond_fn, body_fn);
+        ops::While(scope.WithOpName("while/LoopCond"),
+                   std::initializer_list<Input>{x, y}, cond_fn, body_fn);
     auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
     auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
     GraphDef expected;
@@ -864,9 +861,9 @@
 
     auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
 
-    auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
-                                  std::initializer_list<Input>{zero, y, x, var},
-                                  outer_cond_fn, outer_body_fn);
+    auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
+                               std::initializer_list<Input>{zero, y, x, var},
+                               outer_cond_fn, outer_body_fn);
     auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
     GraphDef expected;
     TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -921,9 +918,9 @@
     auto one_j = ops::Const<int32>(
         scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
     auto while_op =
-        ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
-                      std::initializer_list<Input>{one_j, arg1, arg2, arg3},
-                      inner_cond_fn, inner_body_fn);
+        ops::While(scope.WithOpName("outer/LoopCond_1"),
+                   std::initializer_list<Input>{one_j, arg1, arg2, arg3},
+                   inner_cond_fn, inner_body_fn);
 
     auto one_outer = ops::Const<int32>(
         scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
index 924fcdd..54cebc6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc
@@ -42,7 +42,7 @@
   const char* const kRetValOp = "_Retval";
   NodeDef ret_def;
   ret_def.set_op(kRetValOp);
-  ret_def.set_name(strings::StrCat(kRetValOp, index));
+  ret_def.set_name(absl::StrCat(kRetValOp, index));
   AddNodeAttr("T", type, &ret_def);
   AddNodeAttr("index", index, &ret_def);
   return AddNodeDefToGraph(ret_def, graph);
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
index 61940e3..582b49d 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h
@@ -43,13 +43,12 @@
 // Returns a textual representation of the names of the nodes in the input.
 template <typename T>
 string NodesToString(const T& nodes) {
-  return strings::StrCat("{",
-                         absl::StrJoin(nodes, ",",
-                                       [](string* output, const Node* node) {
-                                         strings::StrAppend(output,
-                                                            node->name());
-                                       }),
-                         "}");
+  return absl::StrCat("{",
+                      absl::StrJoin(nodes, ",",
+                                    [](string* output, const Node* node) {
+                                      absl::StrAppend(output, node->name());
+                                    }),
+                      "}");
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 6e3c4b0..7c3ad44 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,6 +25,7 @@
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/union_find.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -34,6 +35,7 @@
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 namespace {
@@ -132,7 +134,7 @@
 StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
   const char* const kArgOp = "_Arg";
   NodeDef arg_def;
-  NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
+  NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp);
   builder.Attr("T", type);
   builder.Attr("index", index);
   TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
@@ -473,12 +475,19 @@
     }
   }
 
-  // Builds the condition and body functions.
+  // Builds the condition and body functions. Note that we call
+  // FunctionalizeCond() on cond_graph and body_graph because they might
+  // still contain unfunctionalized "if" nodes. Functionalize them before
+  // they are encapsulated in a FunctionDef.
   std::unique_ptr<Graph> cond_graph;
   TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+  FixupSourceAndSinkEdges(cond_graph.get());
+  TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
   DataTypeVector arg_types;
   std::unique_ptr<Graph> body_graph;
   TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+  FixupSourceAndSinkEdges(body_graph.get());
+  TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
 
   VLOG(2) << "Frame " << frame->name << " condition: "
           << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -487,9 +496,9 @@
   static std::atomic<int64> sequence_num(0LL);
   int64 id = ++sequence_num;
   NameAttrList cond_name;
-  cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
+  cond_name.set_name(absl::StrCat("_functionalize_cond_", id));
   NameAttrList body_name;
-  body_name.set_name(strings::StrCat("_functionalize_body_", id));
+  body_name.set_name(absl::StrCat("_functionalize_body_", id));
   FunctionDef cond_fdef;
   TF_RETURN_IF_ERROR(
       GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
@@ -510,7 +519,7 @@
 
   // Builds a While operator.
   NodeDef while_def;
-  NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+  NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
   builder.Attr("T", arg_types);
   builder.Attr("cond", cond_name);
   builder.Attr("body", body_name);
@@ -653,9 +662,9 @@
 
   // There should be no cycle at this point, since while loops have been removed
   // from graph.
-  // Check that the newly added XlaWhile nodes don't feed into themselves.
+  // Check that the newly added While nodes don't feed into themselves.
   for (const Node* node : graph->op_nodes()) {
-    if (node->def().op() == "XlaWhile") {
+    if (node->def().op() == "While") {
       TF_RETURN_WITH_CONTEXT_IF_ERROR(
           CheckNodeNotInCycle(node, graph->num_node_ids()),
           "Functionalizing loop failed.");
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 1ed1fb3..c019a28 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -20,7 +20,6 @@
 #include <vector>
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
@@ -81,7 +80,7 @@
       TF_ASSIGN_OR_RETURN(auto literal,
                           client->ComputeConstant(constant_graph));
       TF_RETURN_IF_ERROR(
-          LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
+          LiteralToHostTensor(literal, arg.type, &arg.constant_value));
     } else {
       arg.kind = XlaCompiler::Argument::kParameter;
     }
@@ -127,7 +126,7 @@
     TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
         << "Not supported node: " << n->DebugString();
     params.op_kernel = op_kernel.get();
-    gtl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
+    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
     params.output_attr_array = output_attr.data();
 
     // tensor_inputs_ is a buffer reused across graph traversal. We clean up and
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index 127562e..ab7cac7 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -89,7 +89,7 @@
   ScopedStepContainer* step_container_;
   // A buffer to hold tensor inputs to a node, this is reused across the graph
   // traversal.
-  gtl::InlinedVector<TensorValue, 4> tensor_inputs_;
+  absl::InlinedVector<TensorValue, 4> tensor_inputs_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 4c776fb..46794f7 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -115,9 +115,6 @@
     deps = [
         ":if_op",
         ":while_op",
-        "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:batch_dot",
@@ -168,14 +165,11 @@
         "//tensorflow/core/kernels:sparse_to_dense_op",
         "//tensorflow/core/kernels:stack_ops",
         "//tensorflow/core/kernels:training_ops",
-    ] + if_mkl(
-        [
-            "//tensorflow/core/kernels:mkl_transpose_op",
-        ],
-        [
-            "//tensorflow/core/kernels:transpose_op",
-        ],
-    ),
+        "//tensorflow/core/kernels:transpose_op",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
 )
 
 tf_kernel_library(
@@ -184,6 +178,7 @@
     hdrs = ["while_op.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
@@ -201,6 +196,7 @@
     hdrs = ["if_op.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:common",
+        "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/ops:xla_ops",
         "//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index edced6b..a18e049 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -26,7 +26,7 @@
                   absl::Span<const int64> block_shape,
                   const xla::Literal& crops) {
   const int input_rank = input_tensor_shape.dims();
-  const gtl::InlinedVector<int64, 4> input_shape =
+  const absl::InlinedVector<int64, 4> input_shape =
       input_tensor_shape.dim_sizes();
   const int block_rank = block_shape.size();
 
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index 2e383b1..182f7c9 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -39,7 +39,7 @@
     OP_REQUIRES(
         ctx, ctx->num_inputs() == 2,
         errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
-    gtl::InlinedVector<BCast::Vec, 2> shapes;
+    absl::InlinedVector<BCast::Vec, 2> shapes;
     for (int i = 0; i < ctx->num_inputs(); ++i) {
       const TensorShape in_shape = ctx->InputShape(i);
       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
@@ -88,7 +88,7 @@
         ctx, ctx->num_inputs() == 2,
         errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
 
-    gtl::InlinedVector<BCast::Vec, 4> shapes;
+    absl::InlinedVector<BCast::Vec, 4> shapes;
     for (int i = 0; i < ctx->num_inputs(); ++i) {
       const TensorShape in_shape = ctx->InputShape(i);
       OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 12b0e38..e96a1ad 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -48,7 +48,7 @@
     OP_REQUIRES(ctx, kRequiredDims == input_rank,
                 errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                         "; got: ", input_rank));
-    const gtl::InlinedVector<int64, 4> input_shape =
+    const absl::InlinedVector<int64, 4> input_shape =
         input_tensor_shape.dim_sizes();
 
     xla::XlaOp input = ctx->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index a3389d5..4af1e8b 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -34,15 +34,12 @@
       : XlaOpKernel(context) {}
 
   void Compile(XlaOpKernelContext* ctx) override {
-    VLOG(3) << "DynamicUpdateSliceOp::Compile";
+    DataType index_type = ctx->InputType("indices");
+    CHECK(index_type == DT_INT32 || index_type == DT_INT64);
 
-    DataType index_type = input_type(2);
-    OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
-                errors::InvalidArgument("index must be int32 or int64"));
-
-    const TensorShape input_shape = ctx->InputShape(0);
-    const TensorShape update_shape = ctx->InputShape(1);
-    const TensorShape index_shape = ctx->InputShape(2);
+    const TensorShape input_shape = ctx->InputShape("input");
+    const TensorShape update_shape = ctx->InputShape("update");
+    const TensorShape index_shape = ctx->InputShape("indices");
 
     OP_REQUIRES(
         ctx,
@@ -57,13 +54,56 @@
                                 input_shape.DebugString(), "; update shape is ",
                                 update_shape.DebugString()));
 
-    xla::XlaOp result =
-        xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
+    xla::XlaOp result = xla::DynamicUpdateSlice(
+        ctx->Input("input"), ctx->Input("update"), ctx->Input("indices"));
     ctx->SetOutput(0, result);
   }
 };
 
 REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp);
 
+class DynamicSliceOp : public XlaOpKernel {
+ public:
+  explicit DynamicSliceOp(OpKernelConstruction* context)
+      : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* ctx) override {
+    DataType index_type = ctx->InputType("start_indices");
+    CHECK(index_type == DT_INT32 || index_type == DT_INT64);
+    CHECK(index_type == ctx->InputType("size_indices"));
+
+    const TensorShape input_shape = ctx->InputShape("input");
+    const TensorShape start_indices_shape = ctx->InputShape("start_indices");
+    const TensorShape size_indices_shape = ctx->InputShape("size_indices");
+
+    OP_REQUIRES(ctx,
+                TensorShapeUtils::IsVector(start_indices_shape) &&
+                    start_indices_shape.num_elements() == input_shape.dims(),
+                errors::InvalidArgument(
+                    "start_indices must be a vector with length equal to "
+                    "input rank, but input rank is ",
+                    input_shape.dims(), " and start_indices has shape ",
+                    start_indices_shape.DebugString()));
+    OP_REQUIRES(ctx,
+                TensorShapeUtils::IsVector(size_indices_shape) &&
+                    size_indices_shape.num_elements() == input_shape.dims(),
+                errors::InvalidArgument(
+                    "size_indices must be a vector with length equal to "
+                    "input rank, but input rank is ",
+                    input_shape.dims(), " and size_indices has shape ",
+                    size_indices_shape.DebugString()));
+
+    std::vector<int64> size_indices;
+    OP_REQUIRES_OK(
+        ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices));
+    xla::XlaOp result = xla::DynamicSlice(
+        ctx->Input("input"), ctx->Input("start_indices"), size_indices);
+    ctx->SetOutput(0, result);
+  }
+};
+
+REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"),
+                DynamicSliceOp);
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6e1dbf5..56da50f 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/tf2xla/kernels/if_op.h"
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+    has_token_input_output_ = false;
+  } else {
+    has_token_input_output_ = !token_input_nodes_.empty();
+  }
 }
 
 // TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@
   options.resolve_compile_time_constants = false;
   options.return_updated_values_for_all_resources = true;
   options.is_entry_computation = false;
+  options.add_token_input_output = has_token_input_output_;
   XlaCompiler* compiler = ctx->compiler();
 
   XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@
   std::vector<xla::XlaOp> inputs(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = then_result.input_mapping[i] + 1;
-    if (ctx->input_type(input_num) == DT_RESOURCE) {
+    if (has_token_input_output_ && i == num_inputs - 1) {
+      // Set token input for this "if" op.
+      std::vector<xla::XlaOp> token_inputs;
+      for (const string& node_name : token_input_nodes_) {
+        auto token_or = compiler->GetNodeToken(node_name);
+        OP_REQUIRES_OK(ctx, token_or.status());
+        token_inputs.push_back(token_or.ValueOrDie());
+      }
+      inputs[i] = xla::AfterAll(b, token_inputs);
+    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
       XlaResource* resource;
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@
     }
     ctx->SetOutput(i, output_handle);
   }
+  if (has_token_input_output_) {
+    // Set token output for this "if" op.
+    xla::XlaOp token_output =
+        xla::GetTupleElement(outputs, output_types_.size());
+    auto shape_or = b->GetShape(token_output);
+    OP_REQUIRES_OK(ctx, shape_or.status());
+    OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+                errors::FailedPrecondition(
+                    "Token output is not token type: ",
+                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+  }
 
   // Updates the values of any resource variables modified by the conditional
   // bodies.
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index f9bc98a..7783e13 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -52,6 +52,8 @@
   DataType cond_type_;
   DataTypeVector input_types_;
   DataTypeVector output_types_;
+  bool has_token_input_output_;
+  std::vector<string> token_input_nodes_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 22a45b2..3d81ae9 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -78,14 +78,14 @@
     std::vector<xla::XlaOp> args;
     args.push_back(ctx->Input(0));
     args.push_back(xla::ConstantLiteral(
-        &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+        &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
     if (input_shape.dims() > 1) {
       // Don't bother passing the output shape and dim for the 1d case, since
       // the shape is always a scalar and the dim is always 0.
       args.push_back(xla::ConstantLiteral(
-          &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+          &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
       args.push_back(
-          xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
+          xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
     }
 
     xla::Shape xla_shape =
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index f6f158a..27690c1 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -138,7 +138,7 @@
   int num_dims = num_spatial_dims + 2;
   int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format);
   int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format);
-  gtl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
+  absl::InlinedVector<int64, 4> spatial_dimensions(num_spatial_dims);
   for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) {
     spatial_dimensions[spatial_dim] =
         GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim);
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 5982485..118f279 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -69,7 +69,7 @@
   VLOG(1) << "data shape: " << data_shape.DebugString();
   VLOG(1) << "axes      : " << absl::StrJoin(axes, ",");
 
-  gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
+  absl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
   std::vector<int64> xla_axes;
   int64 num_elements_reduced = 1LL;
   for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
@@ -103,7 +103,7 @@
 
   xla::XlaBuilder* const b = ctx->builder();
   // Construct the builder for the reduction lambda.
-  xla::XlaBuilder r(strings::StrCat(desc, "-reduction"));
+  xla::XlaBuilder r(absl::StrCat(desc, "-reduction"));
   xla::PrimitiveType type;
   TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
 
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index c0afcca..8494864 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -97,7 +97,7 @@
 
     // witnessed_axes is used to ensure that the same axis is not marked to be
     // reversed multiple times.
-    gtl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);
+    absl::InlinedVector<bool, 8> witnessed_axes(x_shape.dims(), false);
 
     for (int d = 0; d < axes.size(); ++d) {
       OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 4e0cf99..2e0a69b 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -115,7 +115,7 @@
     // accept legacy scalars, even when they should be forbidden by the graphdef
     // version.
     OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
-                errors::InvalidArgument(strings::StrCat(
+                errors::InvalidArgument(absl::StrCat(
                     "dim input to ExpandDims must be a scalar; got ",
                     dim_shape.DebugString())));
 
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index b7b4f3a..76b79be 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -26,7 +26,7 @@
                   absl::Span<const int64> block_shape,
                   const xla::Literal& paddings) {
   const int input_rank = input_tensor_shape.dims();
-  const gtl::InlinedVector<int64, 4> input_shape =
+  const absl::InlinedVector<int64, 4> input_shape =
       input_tensor_shape.dim_sizes();
   const int block_rank = block_shape.size();
 
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 4493539..3293c13 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -48,7 +48,7 @@
     OP_REQUIRES(ctx, kRequiredDims == input_rank,
                 errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                         "; got ", input_rank));
-    const gtl::InlinedVector<int64, 4> input_shape =
+    const absl::InlinedVector<int64, 4> input_shape =
         input_tensor_shape.dim_sizes();
 
     xla::XlaOp input = ctx->Input(0);
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index df91900..ee70f50 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -111,7 +111,7 @@
     xla::XlaOp value;
     XlaContext& xc = XlaContext::Get(ctx);
     XlaResource* resource;
-    string name = strings::StrCat("Stack: ", stack_name_);
+    string name = absl::StrCat("Stack: ", stack_name_);
     OP_REQUIRES_OK(
         ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
                                TensorShape(), value, /*tensor_array_size=*/size,
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 472d474..2b2e3de 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -46,9 +46,9 @@
     const TensorShape input_shape = ctx->InputShape(0);
 
     TensorShape final_shape;
-    gtl::InlinedVector<int64, 4> begin;
-    gtl::InlinedVector<int64, 4> end;
-    gtl::InlinedVector<int64, 4> strides;
+    absl::InlinedVector<int64, 4> begin;
+    absl::InlinedVector<int64, 4> end;
+    absl::InlinedVector<int64, 4> strides;
 
     xla::Literal begin_literal, end_literal, strides_literal;
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@@ -72,8 +72,8 @@
                        shrink_axis_mask_, &dummy_processing_shape, &final_shape,
                        &dummy, &dummy, &dummy, &begin, &end, &strides));
 
-    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
-    gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
+    absl::InlinedVector<int64, 4> dimensions_to_reverse;
+    absl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
 
     for (int i = 0; i < begin.size(); ++i) {
       if (strides[i] > 0) {
@@ -127,9 +127,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape processing_shape, final_shape;
-    gtl::InlinedVector<int64, 4> begin;
-    gtl::InlinedVector<int64, 4> end;
-    gtl::InlinedVector<int64, 4> strides;
+    absl::InlinedVector<int64, 4> begin;
+    absl::InlinedVector<int64, 4> end;
+    absl::InlinedVector<int64, 4> strides;
 
     TensorShape input_shape;
     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
@@ -175,7 +175,7 @@
     grad = xla::Reshape(grad, processing_shape.dim_sizes());
 
     // Pad the input gradients.
-    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+    absl::InlinedVector<int64, 4> dimensions_to_reverse;
     xla::PaddingConfig padding_config;
 
     for (int i = 0; i < processing_shape.dims(); ++i) {
@@ -238,9 +238,9 @@
 
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape final_shape;
-    gtl::InlinedVector<int64, 4> begin;
-    gtl::InlinedVector<int64, 4> end;
-    gtl::InlinedVector<int64, 4> strides;
+    absl::InlinedVector<int64, 4> begin;
+    absl::InlinedVector<int64, 4> end;
+    absl::InlinedVector<int64, 4> strides;
 
     xla::Literal begin_literal, end_literal, strides_literal;
     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
@@ -287,8 +287,8 @@
 
     xla::XlaOp rhs = ctx->Input(4);
 
-    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
-    gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
+    absl::InlinedVector<int64, 4> dimensions_to_reverse;
+    absl::InlinedVector<int64, 4> slice_begin, slice_dims;
     for (int i = 0; i < begin.size(); ++i) {
       // TODO(phawkins): implement strides != 1
       OP_REQUIRES(
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index bb114d1..94108b7 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -167,7 +167,7 @@
 
     XlaContext& xc = XlaContext::Get(ctx);
     XlaResource* var;
-    string name = strings::StrCat("TensorArray: ", tensor_array_name_);
+    string name = absl::StrCat("TensorArray: ", tensor_array_name_);
     OP_REQUIRES_OK(
         ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
                                dtype_, shape, value, /*tensor_array_size=*/size,
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index f9148b3..6b303b3 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -61,7 +61,7 @@
 
     std::vector<int64> transposed_order;
     // Check whether permutation is a permutation of integers of [0 .. dims).
-    gtl::InlinedVector<bool, 8> bits(dims);
+    absl::InlinedVector<bool, 8> bits(dims);
     bool is_identity = true;
     for (int i = 0; i < dims; ++i) {
       const int32 d = perm[i];
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 2965182..559414e 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/tf2xla/kernels/while_op.h"
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@
   cond_name_attr_ = *name_attr;
   OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
   body_name_attr_ = *name_attr;
+  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+    has_token_input_output_ = false;
+  } else {
+    has_token_input_output_ = !token_input_nodes_.empty();
+  }
 }
 
 void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@
   body_options.return_updated_values_for_all_resources = true;
   body_options.resolve_compile_time_constants = false;
   body_options.is_entry_computation = false;
+  body_options.add_token_input_output = has_token_input_output_;
   XlaCompiler::CompilationResult body;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                 arguments, &body));
@@ -192,6 +199,7 @@
   cond_options.use_tuple_arg = true;
   cond_options.resolve_compile_time_constants = false;
   cond_options.is_entry_computation = false;
+  cond_options.add_token_input_output = has_token_input_output_;
   XlaCompiler::CompilationResult cond;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                 arguments, &cond));
@@ -238,7 +246,16 @@
   std::vector<xla::XlaOp> inputs(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = body.input_mapping[i];
-    if (ctx->input_type(input_num) == DT_RESOURCE) {
+    if (has_token_input_output_ && i == num_inputs - 1) {
+      // Set token input for this "while" op.
+      std::vector<xla::XlaOp> token_inputs;
+      for (const string& node_name : token_input_nodes_) {
+        auto token_or = compiler->GetNodeToken(node_name);
+        OP_REQUIRES_OK(ctx, token_or.status());
+        token_inputs.push_back(token_or.ValueOrDie());
+      }
+      inputs[i] = xla::AfterAll(builder, token_inputs);
+    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
       XlaResource* resource;
       OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
       OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@
                      xla::GetTupleElement(while_result, i));
     }
   }
+  if (has_token_input_output_) {
+    // Set token output for this "while" op.
+    xla::XlaOp token_output =
+        xla::GetTupleElement(while_result, ctx->num_outputs());
+    auto shape_or = builder->GetShape(token_output);
+    OP_REQUIRES_OK(ctx, shape_or.status());
+    OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+                errors::FailedPrecondition(
+                    "Token output is not token type: ",
+                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+  }
 
   // Updates the values of any resource variables modified by the loop.
   for (int i = 0; i < body.resource_updates.size(); ++i) {
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 67edeba..aeeff40 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -56,6 +56,8 @@
  private:
   NameAttrList cond_name_attr_;
   NameAttrList body_name_attr_;
+  bool has_token_input_output_;
+  std::vector<string> token_input_nodes_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
 };
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
index 8848623..fecc7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc
@@ -84,7 +84,7 @@
 
  private:
   xla::ConvolutionDimensionNumbers dnums_;
-  xla::PrecisionConfigProto precision_config_;
+  xla::PrecisionConfig precision_config_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp);
 };
diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
index 2fed53e..40b15b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc
@@ -54,7 +54,7 @@
 
  private:
   xla::DotDimensionNumbers dnums_;
-  xla::PrecisionConfigProto precision_config_;
+  xla::PrecisionConfig precision_config_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
 };
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 9365d20..8597e7f 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -205,7 +205,7 @@
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/core:lib",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
 )
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index d8c050d..64f2d78 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -28,7 +28,7 @@
 
 xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
                     bool transpose_y, bool conjugate_x, bool conjugate_y,
-                    xla::PrecisionConfigProto::Precision precision) {
+                    xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = x.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
@@ -96,7 +96,7 @@
       y = xla::Conj(y);
     }
 
-    xla::PrecisionConfigProto precision_proto;
+    xla::PrecisionConfig precision_proto;
     precision_proto.add_operand_precision(precision);
     precision_proto.add_operand_precision(precision);
 
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 6cfccd5..6edd63a 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -43,11 +43,11 @@
 // It is computed as:
 //
 //     output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
-                    bool transpose_y = false, bool conjugate_x = false,
-                    bool conjugate_y = false,
-                    xla::PrecisionConfigProto::Precision precision =
-                        xla::PrecisionConfigProto::DEFAULT);
+xla::XlaOp BatchDot(
+    xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
+    bool transpose_y = false, bool conjugate_x = false,
+    bool conjugate_y = false,
+    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index c50a8de..ab3d0a5 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -50,7 +50,7 @@
 //                       l[..., j, j]
 //   return l
 xla::XlaOp CholeskyUnblocked(xla::XlaOp a,
-                             xla::PrecisionConfigProto::Precision precision) {
+                             xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
@@ -150,7 +150,7 @@
 }  // namespace
 
 xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size,
-                    xla::PrecisionConfigProto::Precision precision) {
+                    xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 60cd7de..9a561c3 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -30,9 +30,9 @@
 // TODO(phawkins): check for negative values on the diagonal and return an
 // error, instead of silently yielding NaNs.
 // TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256,
-                    xla::PrecisionConfigProto::Precision precision =
-                        xla::PrecisionConfigProto::HIGHEST);
+xla::XlaOp Cholesky(
+    xla::XlaOp a, int64 block_size = 256,
+    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
index 0a140fa..6b3f2b6 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.cc
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -150,7 +150,7 @@
   xla::XlaOp vs;    // Shape: [..., m, n]
 };
 xla::StatusOr<QRBlockResult> QRBlock(
-    xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) {
+    xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
   const int num_dims = xla::ShapeUtil::Rank(a_shape);
@@ -257,7 +257,7 @@
 xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
     xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
     xla::XlaOp taus, int64 m, int64 n,
-    xla::PrecisionConfigProto::Precision precision) {
+    xla::PrecisionConfig::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
   int64 n_index = batch_dims.size() + 1;
@@ -332,7 +332,7 @@
 // rather than WY transformations.
 xla::StatusOr<QRDecompositionResult> QRDecomposition(
     xla::XlaOp a, bool full_matrices, int64 block_size,
-    xla::PrecisionConfigProto::Precision precision) {
+    xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
   const int num_dims = xla::ShapeUtil::Rank(a_shape);
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
index 8a389fb..24b537a 100644
--- a/tensorflow/compiler/tf2xla/lib/qr.h
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -35,8 +35,7 @@
 
 xla::StatusOr<QRDecompositionResult> QRDecomposition(
     xla::XlaOp a, bool full_matrices, int64 block_size = 128,
-    xla::PrecisionConfigProto::Precision precision =
-        xla::PrecisionConfigProto::HIGHEST);
+    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 37b2240..6524c2a 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -110,9 +110,9 @@
   });
 }
 
-xla::XlaOp InvertDiagonalBlocks(
-    xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a,
-    xla::PrecisionConfigProto::Precision precision) {
+xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower,
+                                bool transpose_a, bool conjugate_a,
+                                xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = diag_blocks.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     // Input is a batch of square lower triangular square matrices. Its shape is
@@ -216,7 +216,7 @@
       dnums.add_rhs_batch_dimensions(0);
       dnums.add_lhs_contracting_dimensions(2);
       dnums.add_rhs_contracting_dimensions(1);
-      xla::PrecisionConfigProto precision_proto;
+      xla::PrecisionConfig precision_proto;
       precision_proto.add_operand_precision(precision);
       precision_proto.add_operand_precision(precision);
       auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
@@ -245,7 +245,7 @@
 xla::XlaOp SolveWithInvertedDiagonalBlocks(
     xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side,
     bool lower, bool transpose_a, bool conjugate_a,
-    xla::PrecisionConfigProto::Precision precision) {
+    xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape,
@@ -346,7 +346,7 @@
 xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
                            bool lower, bool transpose_a, bool conjugate_a,
                            int64 block_size,
-                           xla::PrecisionConfigProto::Precision precision) {
+                           xla::PrecisionConfig::Precision precision) {
   xla::XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
     TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index ac42a48..2303234 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -57,11 +57,10 @@
 //
 // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
 // blocking is used.
-xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
-                           bool lower, bool transpose_a, bool conjugate_a,
-                           int64 block_size = 128,
-                           xla::PrecisionConfigProto::Precision precision =
-                               xla::PrecisionConfigProto::HIGHEST);
+xla::XlaOp TriangularSolve(
+    xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a,
+    bool conjugate_a, int64 block_size = 128,
+    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index c267848..804671f 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -64,31 +64,31 @@
   xla::Literal literal;
   switch (type) {
     case xla::U8:
-      literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
+      literal = xla::LiteralUtil::CreateR0<uint8>(value);
       break;
     case xla::U32:
-      literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
+      literal = xla::LiteralUtil::CreateR0<uint32>(value);
       break;
     case xla::U64:
-      literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
+      literal = xla::LiteralUtil::CreateR0<uint64>(value);
       break;
     case xla::S8:
-      literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
+      literal = xla::LiteralUtil::CreateR0<int8>(value);
       break;
     case xla::S32:
-      literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
+      literal = xla::LiteralUtil::CreateR0<int32>(value);
       break;
     case xla::S64:
-      literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
+      literal = xla::LiteralUtil::CreateR0<int64>(value);
       break;
     case xla::F32:
-      literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
+      literal = xla::LiteralUtil::CreateR0<float>(value);
       break;
     case xla::F64:
-      literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
+      literal = xla::LiteralUtil::CreateR0<double>(value);
       break;
     case xla::C64:
-      literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
+      literal = xla::LiteralUtil::CreateR0<complex64>(value);
       break;
     case xla::PRED:
       LOG(FATAL) << "pred element type is not integral";
@@ -96,12 +96,12 @@
     case xla::U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case xla::BF16:
-      literal = std::move(
-          *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+      literal =
+          xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value));
       break;
     case xla::F16:
-      literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
-          static_cast<xla::half>(value)));
+      literal =
+          xla::LiteralUtil::CreateR0<xla::half>(static_cast<xla::half>(value));
       break;
     case xla::TUPLE:
       LOG(FATAL) << "tuple element type is not integral";
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 5300e2c..594ab1d 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -24,7 +24,7 @@
 xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
     const LoopConditionFunction& condition_function,
     const LoopBodyFunction& body_function,
-    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
     xla::XlaBuilder* builder) {
   int arity = initial_values.size();
   std::vector<xla::Shape> var_shapes;
@@ -47,7 +47,7 @@
 
   // Build the condition.
   std::unique_ptr<xla::XlaBuilder> cond_builder =
-      builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
+      builder->CreateSubBuilder(absl::StrCat(name, "_condition"));
   {
     auto parameter =
         xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
@@ -61,7 +61,7 @@
 
   // Build the body.
   std::unique_ptr<xla::XlaBuilder> body_builder =
-      builder->CreateSubBuilder(strings::StrCat(name, "_body"));
+      builder->CreateSubBuilder(absl::StrCat(name, "_body"));
   {
     auto parameter =
         xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
@@ -84,7 +84,7 @@
 xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
     xla::XlaBuilder* builder) {
   auto while_cond_fn =
       [&](absl::Span<const xla::XlaOp> values,
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h
index 115ebf3..f2134bb 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.h
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.h
@@ -19,11 +19,11 @@
 #include <functional>
 #include <vector>
 
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
 
 namespace tensorflow {
 
@@ -50,7 +50,7 @@
 xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
     const LoopConditionFunction& condition_function,
     const LoopBodyFunction& body_function,
-    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
     xla::XlaBuilder* builder);
 
 // Builds an XLA loop that repeats a computation `num_iterations` times.
@@ -65,7 +65,7 @@
 xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
     int64 num_iterations, xla::PrimitiveType num_iterations_type,
     const ForEachIndexBodyFunction& body_function,
-    absl::Span<const xla::XlaOp> initial_values, StringPiece name,
+    absl::Span<const xla::XlaOp> initial_values, absl::string_view name,
     xla::XlaBuilder* builder);
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index 7dc16b5..15f4c38 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -22,51 +22,61 @@
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
+namespace {
 
 TEST(LiteralUtil, LiteralToHostTensor) {
   // int64 literal can only be converted to an int64 host tensor.
-  {
-    std::vector<int64> int64_values = {1, 2, 3};
-    std::unique_ptr<xla::Literal> int64_values_literal =
-        xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
-    Tensor host_tensor;
-    EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
-              LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
-                  .error_message());
-    EXPECT_EQ(
-        "Cannot convert literal of type S64 to tensor of type qint32",
-        LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
-            .error_message());
-    EXPECT_TRUE(
-        LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
-            .ok());
-    test::ExpectTensorEqual<int64>(host_tensor,
-                                   test::AsTensor<int64>(int64_values));
-  }
-
-  {
-    // Repeat tests with int32.
-    Tensor host_tensor;
-    std::vector<int32> int32_values = {10, 11};
-    std::unique_ptr<xla::Literal> int32_values_literal =
-        xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
-    EXPECT_TRUE(
-        LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
-            .ok());
-    test::ExpectTensorEqual<int32>(host_tensor,
-                                   test::AsTensor<int32>(int32_values));
-
-    EXPECT_TRUE(
-        LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
-            .ok());
-    std::vector<qint32> qint32_values = {10, 11};
-    test::ExpectTensorEqual<qint32>(host_tensor,
-                                    test::AsTensor<qint32>(qint32_values));
-
-    EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
-              LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
-                  .error_message());
-  }
+  std::vector<int64> int64_values = {1, 2, 3};
+  xla::Literal int64_values_literal =
+      xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
+  Tensor host_tensor;
+  EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
+            LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+                .error_message());
+  EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+            LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
+                .error_message());
+  EXPECT_TRUE(
+      LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
+  test::ExpectTensorEqual<int64>(host_tensor,
+                                 test::AsTensor<int64>(int64_values));
 }
 
+template <class T>
+using LiteralUtilTest = ::testing::Test;
+using Types =
+    ::testing::Types<std::pair<int8, qint8>, std::pair<uint8, quint8>,
+                     std::pair<int16, qint16>, std::pair<uint16, quint16>,
+                     std::pair<int32, qint32>>;
+
+TYPED_TEST_CASE(LiteralUtilTest, Types);
+
+TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) {
+  using int_type = typename TypeParam::first_type;
+  using qint_type = typename TypeParam::second_type;
+
+  Tensor host_tensor;
+  std::vector<int_type> int_values = {10, 11};
+  xla::Literal int_values_literal =
+      xla::LiteralUtil::CreateR1(absl::Span<const int_type>(int_values));
+  EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+                                  DataTypeToEnum<int_type>::value, &host_tensor)
+                  .ok());
+  test::ExpectTensorEqual<int_type>(host_tensor,
+                                    test::AsTensor<int_type>(int_values));
+
+  EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+                                  DataTypeToEnum<qint_type>::value,
+                                  &host_tensor)
+                  .ok());
+  std::vector<qint_type> qint_values = {10, 11};
+  test::ExpectTensorEqual<qint_type>(host_tensor,
+                                     test::AsTensor<qint_type>(qint_values));
+
+  EXPECT_EQ(
+      error::INVALID_ARGUMENT,
+      LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code());
+}
+
+}  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 2cd9ae7..0236350 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -83,7 +83,7 @@
 rhs_dilation: dilation to apply between kernel elements
 feature_group_count: number of feature groups for grouped convolution.
 dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
-precision_config: a serialized xla::PrecisionConfigProto proto.
+precision_config: a serialized xla::PrecisionConfig proto.
 )doc");
 
 REGISTER_OP("XlaDot")
@@ -102,7 +102,36 @@
 lhs: the LHS tensor
 rhs: the RHS tensor
 dimension_numbers: a serialized xla::DotDimensionNumbers proto.
-precision_config: a serialized xla::PrecisionConfigProto proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+)doc");
+
+REGISTER_OP("XlaDynamicSlice")
+    .Input("input: T")
+    .Input("start_indices: Tindices")
+    .Input("size_indices: Tindices")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr("Tindices: {int32, int64}")
+    .SetShapeFn(shape_inference::UnknownShape)
+    .Doc(R"doc(
+Wraps the XLA DynamicSlice operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
+.
+
+DynamicSlice extracts a sub-array from the input array at dynamic
+start_indices. The size of the slice in each dimension is passed in
+size_indices, which specify the end point of exclusive slice intervals in each
+dimension -- [start, start + size). The shape of start_indices must be rank ==
+1, with dimension size equal to the rank of operand.
+
+input: A `Tensor` of type T.
+
+start_indices: Rank 1 tensor of N integers containing the starting indices of
+  the slice for each dimension. Value must be greater than or equal to zero.
+
+size_indices: List of N integers containing the slice size for each
+  dimension. Each value must be strictly greater than zero, and start + size
+  must be less than or equal to the dimension size.
 )doc");
 
 REGISTER_OP("XlaDynamicUpdateSlice")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 3626de3..27dd18a 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -291,13 +291,7 @@
       name=name)
 
 
-def dynamic_slice(x, starts, sizes, name=None):
-  # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
-  # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
-  # slice if the slice is out of bounds.
-  return array_ops.slice(x, starts, sizes, name=name)
-
-
+dynamic_slice = gen_xla_ops.xla_dynamic_slice
 dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
 
 # TODO(phawkins): generalize tf.pad to support interior padding, and then remove
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 32ba6df..20f2ce2 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace tensorflow {
-/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString(
+/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
     XlaResourceOpKind op_kind) {
   switch (op_kind) {
     case XlaResourceOpKind::kRead:
@@ -30,11 +30,11 @@
   }
 }
 
-static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
-  gtl::FlatMap<StringPiece, XlaResourceOpInfo>* result =
-      new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
+static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
+CreateResourceOpInfoMap() {
+  auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
 
-  auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
+  auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
                  XlaResourceKind resource_kind) {
     auto insert_result =
         result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
@@ -103,23 +103,23 @@
   return result;
 }
 
-static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
+static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
 GetStaticResourceOpInfoMap() {
-  static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
+  static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
       CreateResourceOpInfoMap();
   return *op_info_map;
 }
 
-const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) {
-  const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
+const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
+  const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
       GetStaticResourceOpInfoMap();
   auto it = op_infos.find(op);
   return it == op_infos.end() ? nullptr : &it->second;
 }
 
 namespace resource_op_table_internal {
-std::vector<StringPiece> GetKnownResourceOps() {
-  std::vector<StringPiece> result;
+std::vector<absl::string_view> GetKnownResourceOps() {
+  std::vector<absl::string_view> result;
   for (const auto& p : GetStaticResourceOpInfoMap()) {
     result.push_back(p.first);
   }
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h
index 7f627a6..61c7a56 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.h
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.h
@@ -19,7 +19,7 @@
 #include <string>
 #include <vector>
 
-#include "tensorflow/core/lib/core/stringpiece.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/platform/logging.h"
 
 // Exposes information about the resource operations supported by tf2xla in a
@@ -47,7 +47,7 @@
   XlaResourceOpKind kind() const { return op_kind_; }
   XlaResourceKind resource_kind() const { return resource_kind_; }
 
-  static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind);
+  static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind);
 
  private:
   XlaResourceOpKind op_kind_;
@@ -57,13 +57,13 @@
 // Returns a XlaResourceOpInfo describing `op` if it is a resource operation
 // supported by tf2xla, otherwise returns null (i.e. if this returns null then
 // `op` is either not a resource operation or is unsupported by XLA).
-const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op);
+const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op);
 
 namespace resource_op_table_internal {
 // NB! Implementation detail exposed for unit testing, do not use.
 //
 // Returns the set of resource operations known by this module.
-std::vector<StringPiece> GetKnownResourceOps();
+std::vector<absl::string_view> GetKnownResourceOps();
 }  // namespace resource_op_table_internal
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
index 0343f80..a85ef04 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
@@ -34,7 +34,7 @@
 
 TEST(ResourceOperationTableTest, HaveAllResourceOps) {
   gtl::FlatMap<string, bool> known_resource_ops;
-  for (StringPiece known_resource_op :
+  for (absl::string_view known_resource_op :
        resource_op_table_internal::GetKnownResourceOps()) {
     ASSERT_TRUE(
         known_resource_ops.insert({string(known_resource_op), false}).second);
diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc
index 2d7eb8b..8aae498 100644
--- a/tensorflow/compiler/tf2xla/sharding_util.cc
+++ b/tensorflow/compiler/tf2xla/sharding_util.cc
@@ -17,7 +17,6 @@
 #include "absl/strings/match.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
new file mode 100644
index 0000000..6cd7b24
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
+
+const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
+  std::set<std::string> results;
+  Node* first_side_effecting_node_on_path = nullptr;
+  ReverseDFS(g,
+             [&](Node* n) {
+               std::vector<string> token_input_nodes;
+               if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
+                                &token_input_nodes)
+                        .ok() ||
+                   token_input_nodes.empty()) {
+                 return;
+               }
+
+               if (first_side_effecting_node_on_path != nullptr) {
+                 return;
+               }
+
+               first_side_effecting_node_on_path = n;
+               results.insert(n->name());
+             },
+             [&](Node* n) {
+               if (first_side_effecting_node_on_path == n) {
+                 first_side_effecting_node_on_path = nullptr;
+               }
+             },
+             NodeComparatorName());
+  return results;
+}
+
+bool HasSideEffectingNodes(const Graph& g) {
+  for (Node* n : g.nodes()) {
+    std::vector<string> token_input_nodes;
+    if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
+            .ok() &&
+        !token_input_nodes.empty()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
new file mode 100644
index 0000000..ad07624
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Side-effecting nodes will have this attribute set. Its value is the list of
+// node names which this node has side-effect dependencies on.
+//
+// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute,
+// because they always have side-effect.
+// If and While nodes may or may not have this attribute, depending on whether
+// their bodies have side-effecting nodes.
+extern const char kXlaTokenInputNodesAttrName[];
+
+// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a
+// node has side-effect dependency on current graph's token input.
+extern const char kXlaTokenArgNodeName[];
+
+// Calculates side-effect dependencies for the graph's token output.
+// Returns a set of node names representing these dependencies.
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
+
+// Returns whether a graph contains side-effecting nodes.
+bool HasSideEffectingNodes(const Graph& g);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index f34af2d..b22d538 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -22,8 +22,10 @@
 #include <utility>
 #include <vector>
 
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -41,7 +43,6 @@
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -75,7 +76,7 @@
     auto node_it = node_map.find(remap_it->second);
     if (node_it == node_map.end()) {
       // Strip off the aot_feed_#/ prefix.
-      StringPiece name(remap_it->second);
+      absl::string_view name(remap_it->second);
       const auto index = name.find('/');
       if (index > 0) name.remove_prefix(index + 1);
       return errors::InvalidArgument(
@@ -89,7 +90,7 @@
     // explicitly specify or override them.
     Node* arg_node = nullptr;
     TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
+        NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp)
             .Attr("T", BaseType(feed_node->output_type(output_index)))
             .Attr("index", arg_index)
             .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
@@ -136,7 +137,7 @@
     // Connects fetch_node -> retval_node.
     Node* retval_node = nullptr;
     TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
+        NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp)
             .Input(fetch_node, id.output_index())
             .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
             .Attr("index", ret_index)
@@ -256,7 +257,7 @@
   XlaOpRegistry::RegisterCompilationKernels();
   for (Node* node : graph->nodes()) {
     node->set_assigned_device_name(
-        strings::StrCat("/device:", DEVICE_CPU_XLA_JIT));
+        absl::StrCat("/device:", DEVICE_CPU_XLA_JIT));
   }
   std::vector<XlaCompiler::Argument> xla_args;
   TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
@@ -340,6 +341,13 @@
   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                             second_copy_def, g.get()));
   TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+
+  // Functionalize control flow.
+  TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
+  // After control flow functionalization, we might have more FunctionDefs
+  // (then/else branch, loop body). Add them to the graph.
+  TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
+
   *graph = std::move(g);
   return Status::OK();
 }
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 56f7045..ab26d93 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -77,8 +77,8 @@
   // Set up arguments.
   auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
   auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
-  auto x_global_or = client->TransferToServer(*x_literal);
-  auto y_global_or = client->TransferToServer(*y_literal);
+  auto x_global_or = client->TransferToServer(x_literal);
+  auto y_global_or = client->TransferToServer(y_literal);
   TF_EXPECT_OK(x_global_or.status());
   TF_EXPECT_OK(y_global_or.status());
   std::unique_ptr<xla::GlobalData> x_global =
@@ -90,8 +90,8 @@
   auto result_or =
       client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
   TF_EXPECT_OK(result_or.status());
-  std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
-  EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+  xla::Literal result = std::move(result_or.ValueOrDie());
+  EXPECT_EQ("(s32[]) (\n42\n)", result.ToString());
 
   config.mutable_feed(0)->mutable_id()->set_output_index(
       123); /* invalid output_index */
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index e284e0b..d6f42ba 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -20,20 +20,23 @@
 #include <set>
 #include <unordered_map>
 
+#include "absl/strings/str_cat.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/tensor_id.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 
 namespace tensorflow {
 
@@ -75,6 +78,8 @@
 
 }  // namespace
 
+const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
+
 Status ValidateConfig(const tf2xla::Config& config) {
   std::set<string> names;
   for (const tf2xla::Feed& feed : config.feed()) {
@@ -112,8 +117,8 @@
     const string name_port = TensorIdToString(feed->id());
     PlaceholderInfo& info = placeholder_info[name_port];
     info.feed = feed;
-    info.placeholder_name = strings::StrCat(
-        "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
+    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
+                                         "/", feed->id().node_name());
     (*feed_remapping)[name_port] = info.placeholder_name;
   }
 
@@ -258,7 +263,7 @@
 }
 
 string TensorIdToString(const tf2xla::TensorId& id) {
-  return strings::StrCat(id.node_name(), ":", id.output_index());
+  return absl::StrCat(id.node_name(), ":", id.output_index());
 }
 
 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
@@ -289,7 +294,7 @@
   return Status::OK();
 }
 
-void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
                                    KernelDef* kdef) {
   for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
     if (constraint.name() == name) {
@@ -323,4 +328,101 @@
   return counter.fetch_add(2);
 }
 
+// TODO(b/77601805): add tests for associated function related stuff.
+bool HasAssociatedFunction(const NodeDef& node_def,
+                           FunctionLibraryRuntime* flr) {
+  if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) {
+    return true;
+  }
+
+  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
+    // Skip gradient op. Gradient op has "f" attr, which is set to the function
+    // we are getting gradient for. That function is not associated with the op.
+    return false;
+  }
+
+  for (const auto& iter : node_def.attr()) {
+    if (iter.second.has_func()) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+    const Node& node, FunctionLibraryRuntime* flr) {
+  std::vector<AssociatedFunctionInfo> results;
+  const string& op = node.type_string();
+  if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
+    // This is a function call node.
+    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+    results.emplace_back(AssociatedFunctionInfo(op, attrs));
+  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
+    // Skip gradient op. Gradient op has "f" attr, which is set to the function
+    // we are getting gradient for. That function is not associated with the op.
+  } else {
+    // Collect all function attrs for the node.
+    for (auto& iter : node.attrs()) {
+      if (iter.second.has_func()) {
+        VLOG(2) << "Found function attr for node " << node.name() << ": "
+                << iter.first << " = " << iter.second.func().name();
+        results.emplace_back(AssociatedFunctionInfo(
+            iter.second.func().name(), iter.second.func().attr(), iter.first));
+      }
+    }
+  }
+  return results;
+}
+
+Status RewriteAssociatedFunction(
+    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+    const AssociatedFunctionInfo& associated_function,
+    const string& rewritten_function_name) {
+  switch (associated_function.type()) {
+    case AssociatedFunctionInfo::kFunctionCallNode: {
+      // Change this node to call the new function.
+      NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
+      for (auto attr : node->attrs()) {
+        builder.Attr(attr.first, attr.second);
+      }
+      for (int i = 0; i < node->num_inputs(); i++) {
+        Node* input_node;
+        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
+        builder.Input(input_node->name(), i, node->input_type(i));
+      }
+      builder.Device(node->assigned_device_name().empty()
+                         ? node->requested_device()
+                         : node->assigned_device_name());
+      NodeDef node_def;
+      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
+      Status s;
+      Node* new_node = graph->AddNode(node_def, &s);
+      TF_RETURN_IF_ERROR(s);
+      for (auto edge : node->in_edges()) {
+        graph->AddEdge(edge->src(), edge->src_output(), new_node,
+                       edge->dst_input());
+      }
+      for (auto edge : node->out_edges()) {
+        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
+                       edge->dst_input());
+      }
+      graph->RemoveNode(node);
+      break;
+    }
+    case AssociatedFunctionInfo::kFunctionAttr: {
+      // Change function attr to rewritten functions.
+      NameAttrList func;
+      TF_RETURN_IF_ERROR(
+          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
+      node->ClearAttr(associated_function.attr_name());
+      func.set_name(rewritten_function_name);
+      node->AddAttr(associated_function.attr_name(), func);
+      break;
+    }
+  }
+
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 33620ef..6065d0b 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -19,6 +19,7 @@
 #include <unordered_map>
 
 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/kernel_def.pb.h"
 #include "tensorflow/core/framework/op.h"
@@ -53,12 +54,73 @@
 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
 
 // Add an allowed data type to the AttrConstraint with the given name.
-void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype,
+void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
                                    KernelDef* kdef);
 
 // Returns the next random seed to use for seeding xla rng.
 uint32 GetXLARandomSeed();
 
+// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
+// a function call, or the node has function attrs).
+class AssociatedFunctionInfo {
+ public:
+  enum AssociatedFunctionType {
+    kFunctionCallNode = 0,
+    kFunctionAttr = 1,
+  };
+
+  // The node is a function call.
+  AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
+      : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
+
+  // The function is an attr of the node.
+  AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
+                         const string& attr_name)
+      : type_(kFunctionAttr),
+        func_name_(func_name),
+        attrs_(attrs),
+        attr_name_(attr_name) {}
+
+  AssociatedFunctionType type() const { return type_; }
+
+  const string& func_name() const { return func_name_; }
+
+  const string& attr_name() const { return attr_name_; }
+
+  const AttrValueMap& attrs() const { return attrs_; }
+
+ private:
+  // Available for all instances.
+  AssociatedFunctionType type_;
+  string func_name_;
+  AttrValueMap attrs_;
+
+  // Only available if the function is defined in an attr.
+  string attr_name_;
+};
+
+// Returns whether the NodeDef has an associated function.
+bool HasAssociatedFunction(const NodeDef& node_def,
+                           FunctionLibraryRuntime* flr);
+
+// Gets functions associated with the node. Current cases:
+// 1. For function call node, its function name;
+// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+    const Node& node, FunctionLibraryRuntime* flr);
+
+// Changes associated functions for the node. Current cases:
+// 1. For function call node, creates a new node with the new function name and
+//    removes the old node;
+// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+Status RewriteAssociatedFunction(
+    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+    const AssociatedFunctionInfo& associated_function,
+    const string& rewritten_function_name);
+
+// Attribute to mark nodes to be executed on host.
+extern const char kXlaOutsideCompilationAttrName[];
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index 2b1f724..68441b3 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -16,6 +16,8 @@
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 
 #include "absl/strings/match.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/data_flow_ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
@@ -25,8 +27,6 @@
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -153,7 +153,7 @@
   tf2xla::Config config;
   for (const auto& fetch_node_name : fetches) {
     auto* fetch = config.add_fetch();
-    fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
+    fetch->set_name(absl::StrCat("fetch_", fetch_node_name));
     fetch->mutable_id()->set_node_name(fetch_node_name);
   }
   return config;
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index c969212..d00b137 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -26,21 +26,26 @@
       *type = xla::PRED;
       return Status::OK();
     case tensorflow::DT_INT8:
+    case tensorflow::DT_QINT8:
       *type = xla::S8;
       return Status::OK();
     case tensorflow::DT_INT16:
+    case tensorflow::DT_QINT16:
       *type = xla::S16;
       return Status::OK();
     case tensorflow::DT_INT32:
+    case tensorflow::DT_QINT32:
       *type = xla::S32;
       return Status::OK();
     case tensorflow::DT_INT64:
       *type = xla::S64;
       return Status::OK();
     case tensorflow::DT_UINT8:
+    case tensorflow::DT_QUINT8:
       *type = xla::U8;
       return Status::OK();
     case tensorflow::DT_UINT16:
+    case tensorflow::DT_QUINT16:
       *type = xla::U16;
       return Status::OK();
     case tensorflow::DT_UINT32:
@@ -64,12 +69,6 @@
     case tensorflow::DT_COMPLEX64:
       *type = xla::C64;
       return Status::OK();
-    case tensorflow::DT_QUINT8:
-      *type = xla::U8;
-      return Status::OK();
-    case tensorflow::DT_QINT32:
-      *type = xla::S32;
-      return Status::OK();
     default:
       return errors::InvalidArgument(
           "Unsupported type in DataTypeToPrimitiveType ",
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index d98237b..7f86050 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -76,12 +76,11 @@
 
 XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
                                            DeviceType type)
-    : LocalDevice(
-          options,
-          Device::BuildDeviceAttributes(
-              strings::StrCat("/device:", type.type(), ":0"), type,
-              Bytes(256 << 20), DeviceLocality(),
-              strings::StrCat("device: XLA compilation device ", type.type()))),
+    : LocalDevice(options, Device::BuildDeviceAttributes(
+                               absl::StrCat("/device:", type.type(), ":0"),
+                               type, Bytes(256 << 20), DeviceLocality(),
+                               absl::StrCat("device: XLA compilation device ",
+                                            type.type()))),
       allocator_(new XlaCompilationAllocator()) {}
 
 XlaCompilationDevice::~XlaCompilationDevice() {}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 0c300c2..105f3b6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,10 +20,10 @@
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -149,6 +149,9 @@
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
         GetFunctionBody(function, flib_runtime_, fbody),
         "Local lookup failed with: ", status.error_message());
+    VLOG(4) << "Function " << function.name() << " in flib_runtime_";
+  } else {
+    VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
   }
   return Status::OK();
 }
@@ -198,14 +201,14 @@
   // lowest-numbered core that consumes the argument. We choose the
   // lowest-numbered core so the assignment is deterministic.
   for (Node* n : graph->nodes()) {
-    if (StringPiece(n->type_string()) == "_Arg") {
+    if (absl::string_view(n->type_string()) == "_Arg") {
       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
     }
   }
   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
   // may have gotten a device assignment from the first loop).
   for (Node* n : graph->nodes()) {
-    if (StringPiece(n->type_string()) == "_Retval") {
+    if (absl::string_view(n->type_string()) == "_Retval") {
       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
     }
   }
@@ -213,8 +216,7 @@
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "XlaCompiler::CompileFunction: "
             << dump_graph::DumpGraphToFile(
-                   strings::StrCat("xla_compile_function_", function_id),
-                   *graph);
+                   absl::StrCat("xla_compile_function_", function_id), *graph);
   }
 
   VLOG(1) << "====================================================";
@@ -292,6 +294,10 @@
               "Invalid resource type in XLAShapeForArgument()");
       }
     }
+    case XlaCompiler::Argument::kToken: {
+      *xla_shape = xla::ShapeUtil::MakeTokenShape();
+      return Status::OK();
+    }
     case XlaCompiler::Argument::kInvalid:
       return errors::Internal("Invalid argument type in XLAShapeForArgument()");
   }
@@ -490,7 +496,8 @@
         }
 
         break;
-      case XlaCompiler::Argument::kParameter: {
+      case XlaCompiler::Argument::kParameter:
+      case XlaCompiler::Argument::kToken: {
         input_mapping->push_back(i);
         break;
       }
@@ -522,7 +529,7 @@
 
   // Use the _Arg nodes in the graph to resolve core assignments.
   for (const Node* n : graph.nodes()) {
-    if (StringPiece(n->type_string()) != "_Arg") continue;
+    if (absl::string_view(n->type_string()) != "_Arg") continue;
     int index;
     TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
     TF_RET_CHECK(index >= 0 && index < args.size())
@@ -581,7 +588,7 @@
           builder, core == -1 ? absl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
       arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
-                                      strings::StrCat("arg", i));
+                                      absl::StrCat("arg", i));
     }
   }
 
@@ -617,6 +624,10 @@
           arg_expression.set_handle(arg_handles[i]);
         }
         break;
+      case XlaCompiler::Argument::kToken: {
+        arg_expression.set_handle(arg_handles[i]);
+        break;
+      }
       case XlaCompiler::Argument::kConstant:
       case XlaCompiler::Argument::kInvalid:
         return errors::Internal(
@@ -644,7 +655,7 @@
   // dependency edge to the _SOURCE node.
   for (int64 i = 0; i < ctx->num_inputs(); ++i) {
     Node* node;
-    string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
+    string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg");
     Status status = NodeBuilder(name, "_Arg")
                         .ControlInput(graph->source_node())
                         .Attr("T", ctx->input_dtype(i))
@@ -657,7 +668,7 @@
   // Similarly with return values, create dummy _Retval nodes fed by `node`.
   for (int64 i = 0; i < ctx->num_outputs(); ++i) {
     Node* node;
-    string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
+    string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval");
     Status status = NodeBuilder(name, "_Retval")
                         .Input(main_node, i)
                         .Attr("T", ctx->expected_output_dtype(i))
@@ -693,7 +704,7 @@
                      const DeviceType& device_type, const string& name) {
   auto maybe_error = [&](const Node* node, const Status& s) -> Status {
     if (!s.ok()) {
-      return errors::InvalidArgument(strings::StrCat(
+      return errors::InvalidArgument(absl::StrCat(
           "Detected unsupported operations when trying to compile graph ", name,
           " on ", device_type.type_string(), ": ", node->def().op(), " (",
           s.error_message(), ")", FormatNodeForError(*node)));
@@ -734,18 +745,13 @@
   if (VLOG_IS_ON(2)) {
     VLOG(2) << "XlaCompiler::CompileGraph: "
             << dump_graph::DumpGraphToFile(
-                   strings::StrCat("xla_compile_graph_", name), *graph);
+                   absl::StrCat("xla_compile_graph_", name), *graph,
+                   flib_runtime_->GetFunctionLibraryDefinition());
   }
 
   // Report the error here if initialization failed.
   TF_RETURN_IF_ERROR(initialization_status_);
 
-  // Converts Tensorflow's graph control-flow constructs into functional
-  // control-flow that can be compiled into XLA code.
-  TF_RETURN_IF_ERROR(
-      FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
-                               graph.get(), local_flib_def_.get()));
-
   // Detect invalid nodes.
   // FunctionalizeControlFlow may remove some nodes from the graph.
   TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
@@ -758,23 +764,71 @@
       &options_.shape_representation_fn);
   core::ScopedUnref context_unref(context);
 
+  std::vector<XlaCompiler::Argument> real_args(args);
+  int token_input_index = -1;
+  if (options.add_token_input_output) {
+    // Add extra token input.
+    token_input_index = real_args.size();
+
+    XlaCompiler::Argument token_arg;
+    token_arg.kind = XlaCompiler::Argument::kToken;
+    real_args.push_back(token_arg);
+  }
+
   std::vector<XlaExpression> arg_expressions;
   std::vector<int> arg_cores;
-  TF_RETURN_IF_ERROR(
-      BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
-                     &arg_cores, &arg_expressions, &result->input_mapping,
-                     &result->xla_input_shapes, options.is_entry_computation));
+  TF_RETURN_IF_ERROR(BuildArguments(
+      *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+      &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+      options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
 
+  PushNodeTokenMapping();
+  // Use std::set instead of std::unordered_set to ensure determinism.
+  std::set<std::string> output_node_token_inputs;
+  if (token_input_index != -1) {
+    // Original token comes from input.
+    auto arg_expression = context->args()[token_input_index];
+    TF_RETURN_IF_ERROR(
+        SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+    // Calculate token inputs for output token.
+    output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+    // If there's no side-effecting op in the graph, use token input as token
+    // output.
+    if (output_node_token_inputs.empty()) {
+      output_node_token_inputs.insert(kXlaTokenArgNodeName);
+    }
+  } else if (options.is_entry_computation) {
+    // Original token is manually created.
+    if (HasSideEffectingNodes(*graph)) {
+      TF_RETURN_IF_ERROR(
+          SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+    }
+  }
+
   TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
                                   flib_runtime_, NextStepId()));
+  if (token_input_index != -1) {
+    // Add extra token output.
+    std::vector<xla::XlaOp> token_inputs;
+    for (const auto& node_name : output_node_token_inputs) {
+      auto token_or = GetNodeToken(node_name);
+      TF_RETURN_IF_ERROR(token_or.status());
+      token_inputs.push_back(token_or.ValueOrDie());
+    }
+    TF_RETURN_IF_ERROR(
+        context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+  }
+  TF_RETURN_IF_ERROR(PopNodeTokenMapping());
 
   int num_nonconst_outputs;
   int num_computation_outputs;
   result->computation = std::make_shared<xla::XlaComputation>();
   result->outputs.resize(context->retvals().size());
   TF_RETURN_IF_ERROR(BuildComputation(
-      args, arg_cores, context->retvals(), context->resources(),
+      real_args, arg_cores, context->retvals(), context->resources(),
       options.return_updated_values_for_all_resources,
       options.always_return_tuple, &builder, result->computation.get(),
       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@@ -913,4 +967,47 @@
   return Status::OK();
 }
 
+void XlaCompiler::PushNodeTokenMapping() {
+  node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+}
+
+Status XlaCompiler::PopNodeTokenMapping() {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  node_token_mapping_stack_.pop();
+  return Status::OK();
+}
+
+Status XlaCompiler::SetNodeToken(const string& node_name,
+                                 const xla::XlaOp& op) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling SetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
+  if (!insert_result.second) {
+    return errors::FailedPrecondition("Token mapping already exists for node ",
+                                      node_name);
+  }
+  return Status::OK();
+}
+
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+  if (node_token_mapping_stack_.empty()) {
+    return errors::FailedPrecondition(
+        "Calling GetNodeToken() when node_token_mapping_stack_ is "
+        "empty.");
+  }
+  auto iter = node_token_mapping_stack_.top().find(node_name);
+  if (iter == node_token_mapping_stack_.top().end()) {
+    return errors::FailedPrecondition("Cannot find token mapping for node ",
+                                      node_name);
+  }
+  return iter->second;
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f4a985..2cc603a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@
 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
 
+#include <stack>
+
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@
 
       // Argument is a run-time parameter.
       kParameter,
+
+      // Argument is an XLA token.
+      kToken,
     };
 
     Kind kind = kInvalid;
@@ -179,6 +185,9 @@
     // True when compiling the entry computation, false for subcomputations
     // (while, call, etc.)
     bool is_entry_computation = true;
+
+    // True when we should add XLA token input & output to the graph/function.
+    bool add_token_input_output = false;
   };
 
   struct OutputDescription {
@@ -384,6 +393,11 @@
   xla::Client* client() const { return options_.client; }
   FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
 
+  void PushNodeTokenMapping();
+  Status PopNodeTokenMapping();
+  Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+  xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
  private:
   // Sets the function body `fbody` to the one registered as `function`.
   Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@
 
   std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
 
+  // This is used to store <node name, token output> mapping. Side-effecting
+  // ops call SetNodeToken() to record their token output, so later side-effecting
+  // ops can use GetNodeToken() to get it and use it as token input.
+  //
+  // It's a stack because we need a mapping like this for each level of nested
+  // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+  // stack, and pop the mapping before returning.
+  std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be3c93a..72b17d0 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -20,10 +20,12 @@
 #include "tensorflow/cc/ops/function_ops.h"
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -32,6 +34,7 @@
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -205,27 +208,22 @@
                                      std::move(graph), args, &result));
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param0_literal =
-      xla::LiteralUtil::CreateR1<int32>({7, 42});
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_
           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected0 =
-      xla::LiteralUtil::CreateR1<int32>({4, 143});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({expected0.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 // Tests compilation of a graph where the _Retval node is not necessarily last
@@ -261,23 +259,20 @@
                                      args, &result));
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param0_literal =
-      xla::LiteralUtil::CreateR1<int32>({7, 42});
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_
           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal));
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
 }
 
 // Tests that the compiler doesn't reorder the parameters.
@@ -405,23 +400,19 @@
     EXPECT_FALSE(result.outputs[1].is_constant);
 
     // Tests that the generated computation works.
-    std::unique_ptr<xla::Literal> param0_literal =
-        xla::LiteralUtil::CreateR1<int32>({7, 42});
+    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
     std::unique_ptr<xla::GlobalData> param0_data =
-        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
     std::unique_ptr<xla::GlobalData> actual =
         client_->Execute(*result.computation, {param0_data.get()})
             .ConsumeValueOrDie();
-    std::unique_ptr<xla::Literal> actual_literal =
+    xla::Literal actual_literal =
         client_->Transfer(*actual).ConsumeValueOrDie();
 
-    std::unique_ptr<xla::Literal> expected0 =
-        xla::LiteralUtil::CreateR1<int32>({-7, -42});
-    std::unique_ptr<xla::Literal> expected_literal =
-        xla::LiteralUtil::MakeTuple({expected0.get()});
-    EXPECT_TRUE(
-        xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+    xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+    xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
+    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
   }
 
   {
@@ -440,24 +431,21 @@
     EXPECT_FALSE(result.outputs[1].is_constant);
 
     // Tests that the generated computation works.
-    std::unique_ptr<xla::Literal> param0_literal =
-        xla::LiteralUtil::CreateR1<int32>({7, 42});
+    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
     std::unique_ptr<xla::GlobalData> param0_data =
-        client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+        client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
     std::unique_ptr<xla::GlobalData> actual =
         client_->Execute(*result.computation, {param0_data.get()})
             .ConsumeValueOrDie();
-    std::unique_ptr<xla::Literal> actual_literal =
+    xla::Literal actual_literal =
         client_->Transfer(*actual).ConsumeValueOrDie();
 
-    std::unique_ptr<xla::Literal> expected0 =
-        xla::LiteralUtil::CreateR0<int32>(7);
-    std::unique_ptr<xla::Literal> expected1 =
-        xla::LiteralUtil::CreateR1<int32>({-7, -42});
-    std::unique_ptr<xla::Literal> expected =
-        xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
-    EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
+    xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
+    xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
+    xla::Literal expected =
+        xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
   }
 }
 
@@ -616,10 +604,17 @@
         auto instr1 = c1.instructions(j);
         auto instr2 = c2.instructions(j);
         instr1.clear_name();
+        instr1.clear_id();
+        instr1.clear_operand_ids();
         instr2.clear_name();
-        // The names of instructions were uniquified by the XlaBuilder, the rest
-        // of the fields should be identical.
+        instr2.clear_id();
+        instr2.clear_operand_ids();
+        // The names of instructions were uniquified by the XlaBuilder and the
+        // unique ids may be different, the rest of the fields should be
+        // identical.
         string str1, str2;
+        LOG(INFO) << "instr1 = " << instr1.DebugString();
+        LOG(INFO) << "instr2 = " << instr2.DebugString();
         instr1.AppendPartialToString(&str1);
         instr2.AppendPartialToString(&str2);
         EXPECT_EQ(str1, str2);
@@ -669,34 +664,26 @@
             update.tensor_array_gradients_accessed);
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> input_base =
-      xla::LiteralUtil::CreateR1<int32>({7, 42});
-  std::unique_ptr<xla::Literal> input_grad2 =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
-  std::unique_ptr<xla::Literal> input =
-      xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
+  xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*input).ConsumeValueOrDie();
+      client_->TransferToServer(input).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_->Execute(*result.computation, {param0_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> output_read =
-      xla::LiteralUtil::CreateR0<int32>(42);
-  std::unique_ptr<xla::Literal> output_base =
-      xla::LiteralUtil::CreateR1<int32>({7, 42});
-  std::unique_ptr<xla::Literal> output_grad1 =
-      xla::LiteralUtil::CreateR1<int32>({0, 1});
-  std::unique_ptr<xla::Literal> output_grad2 =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
-  std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
-      {output_base.get(), output_grad1.get(), output_grad2.get()});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
+  xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
+  xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal output_resource =
+      xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
+  xla::Literal expected_literal =
+      xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 // Tests compilation and execution of a graph that adds two tensors.
@@ -863,29 +850,24 @@
 
 void RunAndCheckVariablesComputation(
     xla::Client* client, const XlaCompiler::CompilationResult& result) {
-  std::unique_ptr<xla::Literal> param0_literal =
-      xla::LiteralUtil::CreateR1<int32>({7, 42});
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<xla::GlobalData> param1_data =
-      client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client
           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected0 =
-      xla::LiteralUtil::CreateR1<int32>({5, 144});
-  std::unique_ptr<xla::Literal> expected1 =
-      xla::LiteralUtil::CreateR1<int32>({4, 143});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
+  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
+  xla::Literal expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 // Tests a simple graph that reads and writes a variable.
@@ -949,20 +931,17 @@
                                      std::move(graph), args, &result));
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::LiteralUtil::CreateR1<int32>({-3, 101});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
   std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_->Execute(*result.computation, {param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 TEST_F(XlaCompilerTest, ReturnResourceHandle) {
@@ -1066,29 +1045,27 @@
            xla::ShapeUtil::MakeShape(xla::S32, {4})})));
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param0_literal =
+  xla::Literal param0_literal =
       xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
-  std::unique_ptr<xla::Literal> param1_literal =
+  xla::Literal param1_literal =
       xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_
           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected0 =
+  xla::Literal expected0 =
       xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
-  std::unique_ptr<xla::Literal> expected1 =
-      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+  xla::Literal expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
@@ -1135,29 +1112,26 @@
            xla::ShapeUtil::MakeShape(xla::S32, {4})})));
 
   // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param0_literal =
+  xla::Literal param0_literal =
       xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
-  std::unique_ptr<xla::Literal> param1_literal =
+  xla::Literal param1_literal =
       xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   std::unique_ptr<xla::GlobalData> actual =
       client_
           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
           .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected0 =
-      xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
-  std::unique_ptr<xla::Literal> expected1 =
-      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
+  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
+  xla::Literal expected_literal =
+      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
 }
 
 // Tests a graph which has a function with an invalid op.
@@ -1252,25 +1226,73 @@
     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
     CopyGraph(*graph, graph_copy.get());
     XlaCompiler::CompilationResult result;
-    status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
-                                   std::move(graph_copy), args, &result);
-    ASSERT_FALSE(status.ok());
-    EXPECT_TRUE(
-        absl::StrContains(status.error_message(),
-                          "The following nodes are unreachable "
-                          "from the source in the graph: {{node NoOp}}"))
-        << status.error_message();
-  }
-
-  // Fix control edges for NoOp.
-  {
-    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
-    CopyGraph(*graph, graph_copy.get());
-    EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
-    XlaCompiler::CompilationResult result;
     TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
                                        std::move(graph_copy), args, &result));
-    EXPECT_EQ(0, result.resource_updates.size());
+  }
+}
+
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+  explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  void Compile(XlaOpKernelContext* ctx) override {
+    OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
+                            name(), xla::CreateToken(ctx->builder())));
+  }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  NodeDef side_effecting_op;
+  side_effecting_op.set_name("DummySideEffectingOp");
+  side_effecting_op.set_op("DummySideEffectingOp");
+  AddNodeAttr(kXlaTokenInputNodesAttrName,
+              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+  Status status;
+  graph->AddNode(side_effecting_op, &status);
+  TF_ASSERT_OK(status);
+  EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+  const std::vector<XlaCompiler::Argument> empty_args;
+  {
+    // The case for entry computation: we don't add token input/output. Instead,
+    // we use CreateToken HLO to create the entry token.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = true;
+    options.add_token_input_output = false;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    EXPECT_EQ(result.xla_input_shapes.size(), 0);
+    EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+  }
+  {
+    // The case for non-entry computation (e.g. while loop body). We add token
+    // input/output.
+    XlaCompiler::CompileOptions options;
+    options.is_entry_computation = false;
+    options.add_token_input_output = true;
+    XlaCompiler compiler(DefaultOptions());
+
+    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+    CopyGraph(*graph, graph_copy.get());
+    XlaCompiler::CompilationResult result;
+    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+                                       empty_args, &result));
+    EXPECT_EQ(result.xla_input_shapes.size(), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+    EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+    EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+    EXPECT_TRUE(xla::ShapeUtil::IsToken(
+        xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
   }
 }
 
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 24a4b92..f247570 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -32,7 +32,6 @@
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace tensorflow {
@@ -120,6 +119,17 @@
   return Status::OK();
 }
 
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+  VLOG(1) << "Adding retval index " << retvals_.size()
+          << " with token to XLA computation";
+  XlaExpression e;
+  e.set_handle(token);
+  // We use DT_INVALID because there is no TF DataType which corresponds to XLA
+  // token. XlaCompiler handles this case separately, so putting it here is OK.
+  retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
+  return Status::OK();
+}
+
 xla::XlaBuilder* XlaContext::builder() { return builder_; }
 
 Status XlaContext::CreateResource(
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 4da8916..d7dbdc9 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -89,6 +89,9 @@
   // As for Retval, but for return values that are resource handles.
   Status AddResourceRetval(int retval_index, XlaResource* resource);
 
+  // As for Retval, but for return values that are XLA tokens.
+  Status AppendTokenRetval(const xla::XlaOp& token);
+
   // Creates a resource with resource `kind` and initial value `handle`. `name`
   // is a descriptive name for use in error messages. See the `XlaResource`
   // constructor for a description of the remaining arguments.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 1499c99..2a9eaee 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -67,7 +67,7 @@
   return GetComputationFromTensor(context_->input(index));
 }
 
-const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
   return GetComputationFromTensor(GetInputTensorByName(name));
 }
 
@@ -75,7 +75,7 @@
   return context_->input(index).shape();
 }
 
-TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
   return GetInputTensorByName(name).shape();
 }
 
@@ -83,6 +83,10 @@
   return context_->input(index).dtype();
 }
 
+DataType XlaOpKernelContext::InputType(absl::string_view name) {
+  return GetInputTensorByName(name).dtype();
+}
+
 xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
   xla::PrimitiveType type;
   Status status = DataTypeToPrimitiveType(input_type(index), &type);
@@ -100,7 +104,7 @@
 }
 
 static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
-                                     StringPiece name) {
+                                     absl::string_view name) {
   int start, stop;
   TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
   if (stop != start + 1) {
@@ -112,7 +116,7 @@
   return start;
 }
 
-Status XlaOpKernelContext::ConstantInput(StringPiece name,
+Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                          xla::Literal* constant_literal) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
   return ConstantInput(index, constant_literal);
@@ -213,16 +217,15 @@
         context_->op_kernel().name(), " input ", index,
         ".\nError: ", constant_graph.status().error_message());
   }
-  xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
-      compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(),
-                                            &layout);
+  xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
+      constant_graph.ValueOrDie(), &layout);
   if (!computed.ok()) {
     return errors::Internal("Error evaluating ", context_->op_kernel().name(),
                             " input ", index,
-                            "as a compile-time constant.\nError: ",
+                            " as a compile-time constant.\nError: ",
                             computed.status().error_message());
   }
-  *constant_literal = std::move(*computed.ValueOrDie());
+  *constant_literal = std::move(computed).ValueOrDie();
 
   return Status::OK();
 }
@@ -265,7 +268,7 @@
   return LiteralToInt64Scalar(literal, out);
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name,
+Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name,
                                                     int64* out) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
   return ConstantInputAsIntScalar(index, out);
@@ -305,7 +308,7 @@
   return LiteralToInt64Vector(literal, out);
 }
 
-Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name,
+Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name,
                                                     std::vector<int64>* out) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
   return ConstantInputAsIntVector(index, out);
@@ -344,7 +347,7 @@
   }
 }
 
-Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name,
+Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name,
                                                        xla::Literal* out) {
   TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
   return ConstantInputAsInt64Literal(index, out);
@@ -361,7 +364,7 @@
   return Status::OK();
 }
 
-Status XlaOpKernelContext::InputList(StringPiece name,
+Status XlaOpKernelContext::InputList(absl::string_view name,
                                      std::vector<xla::XlaOp>* handles,
                                      std::vector<TensorShape>* shapes) {
   OpInputList inputs;
@@ -376,7 +379,7 @@
 }
 
 Status XlaOpKernelContext::ConstantInputList(
-    StringPiece name, std::vector<xla::Literal>* outputs) {
+    absl::string_view name, std::vector<xla::Literal>* outputs) {
   int start, stop;
   TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
   outputs->resize(stop - start);
@@ -429,8 +432,8 @@
                                  value);
 }
 
-Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
-                                             TensorShape* shape,
+Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
+                                             DataType type, TensorShape* shape,
                                              xla::XlaOp* value) {
   return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
                                  shape, value);
@@ -564,7 +567,7 @@
                               handle, builder());
 }
 
-Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                           xla::XlaOp handle) {
   TF_RET_CHECK(handle.valid());
   return AssignVariableTensor(GetInputTensorByName(name), type, context_,
@@ -610,7 +613,7 @@
   return XlaContext::Get(context_).GetOrCreateMul(type);
 }
 
-const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
   const Tensor* tensor;
   CHECK(context_->input(name, &tensor).ok());
   return *tensor;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 45cfa7d..a3a0d10 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -71,6 +71,9 @@
   // Returns the type of input `index`.
   DataType input_type(int index) const;
 
+  // Returns the type of input `name`.
+  DataType InputType(absl::string_view name);
+
   // Returns the type of input `index` as an xla::PrimitiveType. If the type
   // is not representable as an XLA type, sets an error status and returns
   // xla::PRIMITIVE_TYPE_INVALID.
@@ -79,15 +82,15 @@
   // Returns the shape of input `index`.
   TensorShape InputShape(int index);
 
-  // Returns the shape of input `name`.
-  TensorShape InputShape(StringPiece name);
+  // Returns the shape of input with name `name`.
+  TensorShape InputShape(absl::string_view name);
 
   // Returns input `index` as a XlaOp. Unlike
   // OpKernelContext::Input returns a symbolic value rather than a concrete
   // Tensor.
   const xla::XlaOp& Input(int index);
   // Returns input `name` as a XlaOp.
-  const xla::XlaOp& Input(StringPiece name);
+  const xla::XlaOp& Input(absl::string_view name);
 
   // Returns true if all inputs are the same shape, otherwise sets the
   // status to a non-OK value and returns false.
@@ -97,7 +100,7 @@
   // Returns the named list-valued immutable input in "list", as
   // defined in the OpDef.  If the named output is not list-valued,
   // returns a one-element list.
-  Status InputList(StringPiece name, std::vector<xla::XlaOp>* handles,
+  Status InputList(absl::string_view name, std::vector<xla::XlaOp>* handles,
                    std::vector<TensorShape>* shapes);
 
   // Helper methods for constant inputs.
@@ -106,7 +109,7 @@
   // expression cannot be evaluated, e.g., because it depends on unbound
   // parameters, returns a non-OK status.
   Status ConstantInput(int index, xla::Literal* constant_literal);
-  Status ConstantInput(StringPiece name, xla::Literal* constant_literal);
+  Status ConstantInput(absl::string_view name, xla::Literal* constant_literal);
 
   // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
   // InputShape(index), and stores it in `*constant_literal`. If the input
@@ -118,14 +121,15 @@
 
   // Converts a constant scalar int32 or int64 tensor into an int64.
   Status ConstantInputAsIntScalar(int index, int64* out);
-  Status ConstantInputAsIntScalar(StringPiece name, int64* out);
+  Status ConstantInputAsIntScalar(absl::string_view name, int64* out);
 
   // Converts a constant scalar float32 or float64 tensor into a float64.
   Status ConstantInputAsFloatScalar(int index, double* out);
 
   // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
   Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
-  Status ConstantInputAsIntVector(StringPiece name, std::vector<int64>* out);
+  Status ConstantInputAsIntVector(absl::string_view name,
+                                  std::vector<int64>* out);
 
   // Reshapes and converts a constant int32 or int64 tensor into a vector of
   // int64s.
@@ -133,7 +137,7 @@
 
   // Converts a constant int32 or int64 Tensor into an xla int64 Literal.
   Status ConstantInputAsInt64Literal(int index, xla::Literal* out);
-  Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out);
+  Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out);
 
   // Converts a constant 1D int32 or int64 tensor into a TensorShape.
   Status ConstantInputAsShape(int index, TensorShape* shape);
@@ -141,7 +145,7 @@
   // Returns the named list-valued immutable input in "list", as
   // defined in the OpDef.  If the named output is not list-valued,
   // returns a one-element list.
-  Status ConstantInputList(StringPiece name,
+  Status ConstantInputList(absl::string_view name,
                            std::vector<xla::Literal>* literals);
 
   // Outputs
@@ -190,8 +194,8 @@
                            xla::XlaOp* value);
   // Reads the current value of the resouce variable referred to by input
   // `name`.
-  Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
-                           xla::XlaOp* value);
+  Status ReadVariableInput(absl::string_view name, DataType type,
+                           TensorShape* shape, xla::XlaOp* value);
 
   // Assigns the value `handle` to the variable referenced by input
   // `input_index`. The variable must be of `type`. Returns an error if the
@@ -199,7 +203,8 @@
   // different shape.
   Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
   // Assigns the value `handle` to the variable referenced by input `name`.
-  Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
+  Status AssignVariable(absl::string_view name, DataType type,
+                        xla::XlaOp handle);
 
   // Helper routines for the OP_REQUIRES macros
   void CtxFailure(const Status& s);
@@ -248,7 +253,7 @@
 
  private:
   // Returns the tensor of input `name`.
-  const Tensor& GetInputTensorByName(StringPiece name);
+  const Tensor& GetInputTensorByName(absl::string_view name);
 
   OpKernelContext* const context_;
 };
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index dae2d95..b0eeee3 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -371,26 +371,28 @@
   return *r;
 }
 
-XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) {
+XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) {
   registration_.reset(new XlaOpRegistry::OpRegistration);
   registration_->name = string(name);
 }
 
-XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) {
+XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
+    absl::string_view name) {
   XlaOpRegistrationBuilder registration(name);
   return registration;
 }
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
-    absl::Span<const StringPiece> devices) {
+    absl::Span<const absl::string_view> devices) {
   registration_->has_device_whitelist = true;
-  for (StringPiece device : devices) {
+  for (absl::string_view device : devices) {
     registration_->device_whitelist.emplace(device);
   }
   return *this;
 }
 
-XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) {
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
+    absl::string_view device) {
   registration_->has_device_whitelist = true;
   registration_->device_whitelist.emplace(device);
   return *this;
@@ -407,7 +409,7 @@
 }
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
-    StringPiece attr_name, DataType allowed) {
+    absl::string_view attr_name, DataType allowed) {
   std::set<DataType>& types =
       registration_->type_constraints[string(attr_name)];
   types.insert(allowed);
@@ -415,7 +417,7 @@
 }
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
-    StringPiece attr_name, absl::Span<const DataType> allowed) {
+    absl::string_view attr_name, absl::Span<const DataType> allowed) {
   std::set<DataType>& types =
       registration_->type_constraints[string(attr_name)];
   for (DataType t : allowed) {
@@ -425,7 +427,7 @@
 }
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
-    StringPiece input_name) {
+    absl::string_view input_name) {
   registration_->compile_time_constant_inputs.emplace(input_name);
   return *this;
 }
@@ -452,7 +454,7 @@
 }
 
 XlaBackendRegistrar::XlaBackendRegistrar(
-    StringPiece name, absl::Span<const DataType> types,
+    absl::string_view name, absl::Span<const DataType> types,
     XlaOpRegistry::BackendOpFilter op_filter) {
   XlaOpRegistry& registry = XlaOpRegistry::Instance();
   registry.RegisterBackend(string(name), types, op_filter);
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index c640842..74a4885 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -232,18 +232,18 @@
 class XlaOpRegistrationBuilder {
  public:
   // Starts an operator registration chain.
-  static XlaOpRegistrationBuilder Name(StringPiece name);
+  static XlaOpRegistrationBuilder Name(absl::string_view name);
 
   // Specifies a whitelist of devices on which the operator may run.
-  XlaOpRegistrationBuilder& Device(StringPiece devices);
-  XlaOpRegistrationBuilder& Device(absl::Span<const StringPiece> devices);
+  XlaOpRegistrationBuilder& Device(absl::string_view devices);
+  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
 
   // Specifies a type constraint for a type variable attribute. Each constraint
   // specifies the set of types that the type variable may assume.
-  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
+  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                            DataType allowed);
 
-  XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name,
+  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                            absl::Span<const DataType> allowed);
 
   // Specifies that a dummy copy of this operator should not be registered on
@@ -254,13 +254,13 @@
   XlaOpRegistrationBuilder& AllowResourceTypes();
 
   // Mark 'input_name' as an argument whose value must be known at compile-time.
-  XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
+  XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
 
   std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
       XlaOpRegistry::Factory factory);
 
  private:
-  XlaOpRegistrationBuilder(StringPiece name);
+  XlaOpRegistrationBuilder(absl::string_view name);
 
   std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
 };
@@ -288,7 +288,7 @@
 
 class XlaBackendRegistrar {
  public:
-  XlaBackendRegistrar(StringPiece name, absl::Span<const DataType> types,
+  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                       XlaOpRegistry::BackendOpFilter op_filter = nullptr);
 };
 
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 7928fa0..56c2e01 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -43,7 +43,7 @@
   for (const string& gradient : tensor_array_gradients) {
     tensor_array_gradients_[gradient].reset(new XlaResource(
         /*kind=*/kTensorArray, /*arg_num=*/-1,
-        /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_,
+        /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
         xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{}));
   }
 }
@@ -135,7 +135,7 @@
         xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
     gradient.reset(
         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
-                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+                        /*name=*/absl::StrCat("TensorArrayGrad: ", name_),
                         type_, shape_, gradient_value, tensor_array_size_,
                         /*tensor_array_gradients=*/{}));
   }
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 8818f81..5dde5b4 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -37,8 +37,8 @@
 
 Client::~Client() = default;
 
-StatusOr<std::unique_ptr<Literal>> Client::Transfer(
-    const GlobalData& data, const Shape* shape_with_layout) {
+StatusOr<Literal> Client::Transfer(const GlobalData& data,
+                                   const Shape* shape_with_layout) {
   TransferToClientRequest request;
   *request.mutable_data() = data.handle();
   if (shape_with_layout != nullptr) {
@@ -114,7 +114,7 @@
   return Status::OK();
 }
 
-StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+StatusOr<Literal> Client::TransferFromOutfeed(
     const Shape* shape_with_layout, int64 replica_id,
     const DeviceHandle* device_handle) {
   TransferFromOutfeedRequest request;
@@ -162,7 +162,7 @@
   return Status::OK();
 }
 
-StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+StatusOr<Literal> Client::ExecuteAndTransfer(
     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
     const ExecutionOptions* execution_options,
     ExecutionProfile* execution_profile) {
@@ -177,8 +177,8 @@
   return Transfer(*data, shape_with_output_layout);
 }
 
-StatusOr<std::unique_ptr<Literal>> Client::ComputeConstant(
-    const XlaComputation& computation, const Layout* output_layout) const {
+StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
+                                          const Layout* output_layout) const {
   ComputeConstantGraphRequest request;
   *request.mutable_computation() = computation.proto();
   if (output_layout != nullptr) {
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 7960b07..6f4d33c 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -96,8 +96,8 @@
   //
   // If shape_with_layout is not nullptr, it points to a shape whose layout will
   // be the layout of the returned literal.
-  StatusOr<std::unique_ptr<Literal>> Transfer(
-      const GlobalData& data, const Shape* shape_with_layout = nullptr);
+  StatusOr<Literal> Transfer(const GlobalData& data,
+                             const Shape* shape_with_layout = nullptr);
 
   // Transfer the given literal to the server. This allocates memory on the
   // device and copies the literal's contents over. Returns a global data handle
@@ -122,7 +122,7 @@
   // device_handle and replica_id together specify a particular device; a device
   // assigned for the given replica_id among the replicas that the given device
   // handle belongs to.
-  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+  StatusOr<Literal> TransferFromOutfeed(
       const Shape* shape_with_layout, int64 replica_id = 0,
       const DeviceHandle* device_handle = nullptr);
 
@@ -132,7 +132,7 @@
   // Executes the computation with the given arguments and transfers the result
   // to the client as a literal. Parameters are defined the same as for
   // Execute() and Transfer().
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+  StatusOr<Literal> ExecuteAndTransfer(
       const XlaComputation& computation,
       absl::Span<GlobalData* const> arguments,
       const ExecutionOptions* execution_options = nullptr,
@@ -153,7 +153,7 @@
   //
   // If output_layout is non-null, then the output of the computation will be
   // stored using that layout.
-  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
+  StatusOr<Literal> ComputeConstant(
       const XlaComputation& computation,
       const Layout* output_layout = nullptr) const;
 
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 6861521..25cc37e 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -76,7 +76,7 @@
 std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
                                               Client* client) {
   if (DataSizeOfShape(shape) < (1LL << 20)) {
-    StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
+    StatusOr<Literal> literal_status = MakeFakeLiteral(shape);
     if (!literal_status.ok()) {
       // If we got an Unimplemented error, fall back to making the fake data via
       // an on-device computation.
@@ -84,7 +84,7 @@
                tensorflow::error::UNIMPLEMENTED);
       return MakeFakeDataViaDeviceOrDie(shape, client);
     }
-    return client->TransferToServer(*literal_status.ValueOrDie()).ValueOrDie();
+    return client->TransferToServer(literal_status.ValueOrDie()).ValueOrDie();
   }
 
   // If the data is large, generate it on-device.
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 4402ba8..f96b6c9 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -195,9 +195,8 @@
     HloSnapshot* hlo_snapshot) {
   hlo_snapshot->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
-                        LiteralFromShapedBuffer(*argument));
-    *hlo_snapshot->add_arguments() = literal->ToProto();
+    TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument));
+    *hlo_snapshot->add_arguments() = literal.ToProto();
   }
   return Status::OK();
 }
@@ -205,13 +204,12 @@
 Status LocalExecutable::RecordResult(const ShapedBuffer* result,
                                      HloSnapshot* hlo_snapshot) {
   hlo_snapshot->clear_result();
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
-                      LiteralFromShapedBuffer(*result));
-  *hlo_snapshot->mutable_result() = literal->ToProto();
+  TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result));
+  *hlo_snapshot->mutable_result() = literal.ToProto();
   return Status::OK();
 }
 
-StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
+StatusOr<Literal> LocalExecutable::LiteralFromShapedBuffer(
     const ShapedBuffer& shaped_buffer) {
   TF_ASSIGN_OR_RETURN(auto stream,
                       backend_->BorrowStream(shaped_buffer.device_ordinal()));
@@ -277,7 +275,7 @@
   return std::move(scoped_buffer);
 }
 
-StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
+StatusOr<Literal> LocalClient::ShapedBufferToLiteral(
     const ShapedBuffer& shaped_buffer) {
   TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
                                        shaped_buffer.device_ordinal()));
@@ -298,13 +296,13 @@
                                                                literal);
 }
 
-StatusOr<std::unique_ptr<Literal>> LocalClient::TransferFromOutfeedLocal(
-    const Shape& shape, int device_ordinal) {
+StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
+                                                        int device_ordinal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       backend().stream_executor(device_ordinal));
   auto literal = Literal::CreateFromShape(shape);
   TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      executor, shape, literal.get()));
+      executor, shape, &literal));
   return std::move(literal);
 }
 
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 56c3a3d..feb2f8e 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -84,8 +84,7 @@
   Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
 
   // Returns a literal containing the contents of the given ShapedBuffer.
-  StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
-      const ShapedBuffer& shaped_buffer);
+  StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
 
   // The ordinal of the device which this executable was compiled for. The
   // executable can run on all equivalent devices (as determined by
@@ -132,8 +131,7 @@
 
   // Copy the data from the device contained in the given ShapedBuffer and
   // return as a Literal.
-  StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
-      const ShapedBuffer& shaped_buffer);
+  StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
 
   // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
   // as long as the handle is valid.
@@ -151,8 +149,8 @@
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
   // Client::TransferFromOutfeed.
-  StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
-      const Shape& shape, int device_ordinal);
+  StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
+                                             int device_ordinal);
 
   // Returns the device ordinal that corresponds to the given replica number.
   //
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e639028..95ff643 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -134,11 +134,12 @@
 
 StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
   TF_RETURN_IF_ERROR(first_error_);
-  TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
+  TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
+                      LookUpInstructionByHandle(root_id));
 
   ProgramShape program_shape;
 
-  *program_shape.mutable_result() = instructions_[root_id].shape();
+  *program_shape.mutable_result() = root_proto->shape();
 
   // Check that the parameter numbers are continuous from 0, and add parameter
   // shapes and names to the program shape.
@@ -181,9 +182,8 @@
     return;
   }
 
-  CHECK(op_handle < instructions_.size() && op_handle >= 0);
-
-  const HloInstructionProto& instr = instructions_[op_handle];
+  const HloInstructionProto& instr =
+      *(LookUpInstructionByHandle(op_handle).ValueOrDie());
   const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
   switch (opcode) {
     default:
@@ -283,6 +283,7 @@
 
   // Clear data held by this builder.
   this->instructions_.clear();
+  this->handle_to_index_.clear();
   this->embedded_.clear();
   this->parameter_numbers_.clear();
 
@@ -738,7 +739,7 @@
   ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     *instr.mutable_shape() = ShapeUtil::MakeNil();
-    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
+    *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
     return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
   });
 }
@@ -820,7 +821,7 @@
 }
 
 XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
-                      const PrecisionConfigProto* precision_config_proto) {
+                      const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
 
@@ -828,14 +829,13 @@
     dimension_numbers.add_lhs_contracting_dimensions(
         lhs_shape.dimensions_size() == 1 ? 0 : 1);
     dimension_numbers.add_rhs_contracting_dimensions(0);
-    return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto);
+    return DotGeneral(lhs, rhs, dimension_numbers, precision_config);
   });
 }
 
-XlaOp XlaBuilder::DotGeneral(
-    const XlaOp& lhs, const XlaOp& rhs,
-    const DotDimensionNumbers& dimension_numbers,
-    const PrecisionConfigProto* precision_config_proto) {
+XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+                             const DotDimensionNumbers& dimension_numbers,
+                             const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -844,8 +844,8 @@
                         ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
                                                         dimension_numbers));
     *instr.mutable_dot_dimension_numbers() = dimension_numbers;
-    if (precision_config_proto != nullptr) {
-      *instr.mutable_precision_config() = *precision_config_proto;
+    if (precision_config != nullptr) {
+      *instr.mutable_precision_config() = *precision_config;
     }
     return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
   });
@@ -899,28 +899,26 @@
 XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
                        absl::Span<const int64> window_strides, Padding padding,
                        int64 feature_group_count,
-                       const PrecisionConfigProto* precision_config_proto) {
+                       const PrecisionConfig* precision_config) {
   return ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding,
       CreateDefaultConvDimensionNumbers(window_strides.size()),
-      feature_group_count, precision_config_proto);
+      feature_group_count, precision_config);
 }
 
 XlaOp XlaBuilder::ConvWithGeneralPadding(
     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
+    int64 feature_group_count, const PrecisionConfig* precision_config) {
   return ConvGeneral(lhs, rhs, window_strides, padding,
                      CreateDefaultConvDimensionNumbers(window_strides.size()),
-                     feature_group_count, precision_config_proto);
+                     feature_group_count, precision_config);
 }
 
 XlaOp XlaBuilder::ConvWithGeneralDimensions(
     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
+    int64 feature_group_count, const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
     TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -948,7 +946,7 @@
                        MakePadding(base_area_dimensions, window_dimensions,
                                    window_strides, padding),
                        dimension_numbers, feature_group_count,
-                       precision_config_proto);
+                       precision_config);
   });
 }
 
@@ -956,11 +954,10 @@
     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
     absl::Span<const std::pair<int64, int64>> padding,
     const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
+    int64 feature_group_count, const PrecisionConfig* precision_config) {
   return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                             dimension_numbers, feature_group_count,
-                            precision_config_proto);
+                            precision_config);
 }
 
 XlaOp XlaBuilder::ConvGeneralDilated(
@@ -968,8 +965,7 @@
     absl::Span<const std::pair<int64, int64>> padding,
     absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
     const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
+    int64 feature_group_count, const PrecisionConfig* precision_config) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
@@ -990,14 +986,14 @@
 
     TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                         ShapeInference::InferConvolveShape(
-                            lhs_shape, rhs_shape, instr.window(),
-                            dimension_numbers, feature_group_count));
+                            lhs_shape, rhs_shape, feature_group_count,
+                            instr.window(), dimension_numbers));
 
     *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
     instr.set_feature_group_count(feature_group_count);
 
-    if (precision_config_proto != nullptr) {
-      *instr.mutable_precision_config() = *precision_config_proto;
+    if (precision_config != nullptr) {
+      *instr.mutable_precision_config() = *precision_config;
     }
 
     return AddInstruction(std::move(instr), HloOpcode::kConvolution,
@@ -2290,7 +2286,7 @@
   *program_shape->mutable_result() = root->shape();
 
   // We use std::set to keep the instruction ids in ascending order (which is
-  // also a valid denpendency order). The related ops will be added to the
+  // also a valid dependency order). The related ops will be added to the
   // subgraph in the same order.
   std::set<int64> related_ops;
   tensorflow::gtl::FlatSet<int64> related_calls;  // Related computations.
@@ -2298,14 +2294,16 @@
   worklist.push(root->id());
   related_ops.insert(root->id());
   while (!worklist.empty()) {
-    int64 node = worklist.front();
+    int64 handle = worklist.front();
     worklist.pop();
-    for (int64 id : instructions_[node].operand_ids()) {
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+                        LookUpInstructionByHandle(handle));
+    for (int64 id : instr_proto->operand_ids()) {
       if (related_ops.insert(id).second) {
         worklist.push(id);
       }
     }
-    for (int64 called_id : instructions_[node].called_computation_ids()) {
+    for (int64 called_id : instr_proto->called_computation_ids()) {
       related_calls.insert(called_id);
     }
   }
@@ -2313,7 +2311,9 @@
   // Add related ops to the computation.
   for (int64 id : related_ops) {
     auto* instr = entry.add_instructions();
-    *instr = instructions_[id];
+    TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
+                        LookUpInstructionByHandle(id));
+    *instr = *instr_src;
     // Ensures that the instruction names are unique among the graph.
     const string& new_name =
         StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@@ -2420,11 +2420,11 @@
                                            absl::Span<const XlaOp> operands) {
   TF_RETURN_IF_ERROR(first_error_);
 
-  const int64 handle = instructions_.size();
+  const int64 handle = GetUniqueId();
   instr.set_id(handle);
   instr.set_opcode(HloOpcodeString(opcode));
   if (instr.name().empty()) {
-    instr.set_name(StrCat(instr.opcode()));
+    instr.set_name(instr.opcode());
   }
   for (const auto& operand : operands) {
     if (operand.builder_ == nullptr) {
@@ -2442,7 +2442,8 @@
     *instr.mutable_sharding() = *sharding_;
   }
 
-  instructions_.push_back(instr);
+  handle_to_index_[handle] = instructions_.size();
+  instructions_.push_back(std::move(instr));
 
   XlaOp op(handle, this);
   return op;
@@ -2472,10 +2473,16 @@
         op.handle(), op.builder_->name(), this->name());
   }
 
-  if (op.handle() >= instructions_.size() || op.handle() < 0) {
-    return InvalidArgument("no XlaOp value %d", op.handle());
+  return LookUpInstructionByHandle(op.handle());
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
+    int64 handle) const {
+  auto it = handle_to_index_.find(handle);
+  if (it == handle_to_index_.end()) {
+    return InvalidArgument("No XlaOp with handle %d", handle);
   }
-  return &instructions_[op.handle()];
+  return &instructions_[it->second];
 }
 
 // Enqueues a "retrieve parameter value" instruction for a parameter that was
@@ -2594,43 +2601,40 @@
 }
 
 XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
-          const PrecisionConfigProto* precision_config_proto) {
-  return lhs.builder()->Dot(lhs, rhs, precision_config_proto);
+          const PrecisionConfig* precision_config) {
+  return lhs.builder()->Dot(lhs, rhs, precision_config);
 }
 
 XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfigProto* precision_config_proto) {
+                 const PrecisionConfig* precision_config) {
   return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
-                                   precision_config_proto);
+                                   precision_config);
 }
 
 XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> window_strides, Padding padding,
-           int64 feature_group_count,
-           const PrecisionConfigProto* precision_config_proto) {
+           int64 feature_group_count, const PrecisionConfig* precision_config) {
   return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
-                             feature_group_count, precision_config_proto);
+                             feature_group_count, precision_config);
 }
 
-XlaOp ConvWithGeneralPadding(
-    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
-    absl::Span<const std::pair<int64, int64>> padding,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
-  return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
-                                               padding, feature_group_count,
-                                               precision_config_proto);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+                             absl::Span<const int64> window_strides,
+                             absl::Span<const std::pair<int64, int64>> padding,
+                             int64 feature_group_count,
+                             const PrecisionConfig* precision_config) {
+  return lhs.builder()->ConvWithGeneralPadding(
+      lhs, rhs, window_strides, padding, feature_group_count, precision_config);
 }
 
 XlaOp ConvWithGeneralDimensions(
     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count,
-    const PrecisionConfigProto* precision_config_proto) {
+    int64 feature_group_count, const PrecisionConfig* precision_config) {
   return lhs.builder()->ConvWithGeneralDimensions(
       lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
-      precision_config_proto);
+      precision_config);
 }
 
 XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
@@ -2638,10 +2642,10 @@
                   absl::Span<const std::pair<int64, int64>> padding,
                   const ConvolutionDimensionNumbers& dimension_numbers,
                   int64 feature_group_count,
-                  const PrecisionConfigProto* precision_config_proto) {
+                  const PrecisionConfig* precision_config) {
   return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
                                     dimension_numbers, feature_group_count,
-                                    precision_config_proto);
+                                    precision_config);
 }
 
 XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
@@ -2651,10 +2655,10 @@
                          absl::Span<const int64> rhs_dilation,
                          const ConvolutionDimensionNumbers& dimension_numbers,
                          int64 feature_group_count,
-                         const PrecisionConfigProto* precision_config_proto) {
+                         const PrecisionConfig* precision_config) {
   return lhs.builder()->ConvGeneralDilated(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
-      dimension_numbers, feature_group_count, precision_config_proto);
+      dimension_numbers, feature_group_count, precision_config);
 }
 
 XlaOp Fft(const XlaOp& operand, FftType fft_type,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 59fbc66..d0c59fa 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stacktrace.h"
@@ -496,20 +497,19 @@
 
   // Enqueues a dot instruction onto the computation.
   XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
-            const PrecisionConfigProto* precision_config_proto = nullptr);
+            const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a general dot instruction onto the computation.
-  XlaOp DotGeneral(
-      const XlaOp& lhs, const XlaOp& rhs,
-      const DotDimensionNumbers& dimension_numbers,
-      const PrecisionConfigProto* precision_config_proto = nullptr);
+  XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+                   const DotDimensionNumbers& dimension_numbers,
+                   const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a convolution instruction onto the computation, which uses the
   // default convolution dimension numbers.
   XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
              absl::Span<const int64> window_strides, Padding padding,
              int64 feature_group_count = 1,
-             const PrecisionConfigProto* precision_config_proto = nullptr);
+             const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a convolution instruction onto the computation, with the caller
   // provided padding configuration in the format returned by MakePadding().
@@ -518,7 +518,7 @@
       absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
       int64 feature_group_count = 1,
-      const PrecisionConfigProto* precision_config_proto = nullptr);
+      const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a convolution instruction onto the computation, with the caller
   // provided dimension numbers configuration.
@@ -527,29 +527,27 @@
       absl::Span<const int64> window_strides, Padding padding,
       const ConvolutionDimensionNumbers& dimension_numbers,
       int64 feature_group_count = 1,
-      const PrecisionConfigProto* precision_config_proto = nullptr);
+      const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a convolution instruction onto the computation, with the caller
   // provided padding configuration as well as the dimension numbers.
-  XlaOp ConvGeneral(
-      const XlaOp& lhs, const XlaOp& rhs,
-      absl::Span<const int64> window_strides,
-      absl::Span<const std::pair<int64, int64>> padding,
-      const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count = 1,
-      const PrecisionConfigProto* precision_config_proto = nullptr);
+  XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+                    absl::Span<const int64> window_strides,
+                    absl::Span<const std::pair<int64, int64>> padding,
+                    const ConvolutionDimensionNumbers& dimension_numbers,
+                    int64 feature_group_count = 1,
+                    const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues a convolution instruction onto the computation, with the caller
   // provided padding configuration, dilation factors and dimension numbers.
-  XlaOp ConvGeneralDilated(
-      const XlaOp& lhs, const XlaOp& rhs,
-      absl::Span<const int64> window_strides,
-      absl::Span<const std::pair<int64, int64>> padding,
-      absl::Span<const int64> lhs_dilation,
-      absl::Span<const int64> rhs_dilation,
-      const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count = 1,
-      const PrecisionConfigProto* precision_config_proto = nullptr);
+  XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+                           absl::Span<const int64> window_strides,
+                           absl::Span<const std::pair<int64, int64>> padding,
+                           absl::Span<const int64> lhs_dilation,
+                           absl::Span<const int64> rhs_dilation,
+                           const ConvolutionDimensionNumbers& dimension_numbers,
+                           int64 feature_group_count = 1,
+                           const PrecisionConfig* precision_config = nullptr);
 
   // Enqueues an FFT instruction onto the computation, of the given type and
   // with the given FFT length.
@@ -958,6 +956,8 @@
                             HloInstructionProto* instr);
 
   StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+      int64 handle) const;
 
   // Internal helper method that does the building for an arbitrary unary op.
   XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1027,6 +1027,10 @@
   // The instructions of this computation.
   std::vector<HloInstructionProto> instructions_;
 
+  // A map from XlaOp::Handle to the index in the instructions_ vector where the
+  // instruction is held.
+  tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
   // The embedded computations used by this computation. Each computation was
   // the entry computation of some XlaComputation, the key is the unique id of
   // that XlaComputation.
@@ -1150,32 +1154,30 @@
   friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
                   absl::Span<const int64> broadcast_dimensions);
   friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
-                   const PrecisionConfigProto* precision_config_proto);
+                   const PrecisionConfig* precision_config);
   friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                           const DotDimensionNumbers& dimension_number,
-                          const PrecisionConfigProto* precision_config_proto);
+                          const PrecisionConfig* precision_config);
   friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
                     absl::Span<const int64> window_strides, Padding padding,
                     int64 feature_group_count,
-                    const PrecisionConfigProto* precision_config_proto);
+                    const PrecisionConfig* precision_config);
   friend XlaOp ConvWithGeneralPadding(
       const XlaOp& lhs, const XlaOp& rhs,
       absl::Span<const int64> window_strides,
       absl::Span<const std::pair<int64, int64>> padding,
-      int64 feature_group_count,
-      const PrecisionConfigProto* precision_config_proto);
+      int64 feature_group_count, const PrecisionConfig* precision_config);
   friend XlaOp ConvWithGeneralDimensions(
       const XlaOp& lhs, const XlaOp& rhs,
       absl::Span<const int64> window_strides, Padding padding,
       const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count,
-      const PrecisionConfigProto* precision_config_proto);
+      int64 feature_group_count, const PrecisionConfig* precision_config);
   friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
                            absl::Span<const int64> window_strides,
                            absl::Span<const std::pair<int64, int64>> padding,
                            const ConvolutionDimensionNumbers& dimension_numbers,
                            int64 feature_group_count,
-                           const PrecisionConfigProto* precision_config_proto);
+                           const PrecisionConfig* precision_config);
   friend XlaOp ConvGeneralDilated(
       const XlaOp& lhs, const XlaOp& rhs,
       absl::Span<const int64> window_strides,
@@ -1183,8 +1185,7 @@
       absl::Span<const int64> lhs_dilation,
       absl::Span<const int64> rhs_dilation,
       const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count,
-      const PrecisionConfigProto* precision_config_proto);
+      int64 feature_group_count, const PrecisionConfig* precision_config);
   friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
                    absl::Span<const int64> fft_length);
   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -1629,27 +1630,27 @@
 
 // Enqueues a dot instruction onto the computation.
 XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
-          const PrecisionConfigProto* precision_config_proto = nullptr);
+          const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a general dot instruction onto the computation.
 XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                  const DotDimensionNumbers& dimension_numbers,
-                 const PrecisionConfigProto* precision_config_proto = nullptr);
+                 const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a convolution instruction onto the computation, which uses the
 // default convolution dimension numbers.
 XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
            absl::Span<const int64> window_strides, Padding padding,
            int64 feature_group_count = 1,
-           const PrecisionConfigProto* precision_config_proto = nullptr);
+           const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration in the format returned by MakePadding().
-XlaOp ConvWithGeneralPadding(
-    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
-    absl::Span<const std::pair<int64, int64>> padding,
-    int64 feature_group_count = 1,
-    const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
+                             absl::Span<const int64> window_strides,
+                             absl::Span<const std::pair<int64, int64>> padding,
+                             int64 feature_group_count = 1,
+                             const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided dimension numbers configuration.
@@ -1657,7 +1658,7 @@
     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
     int64 feature_group_count = 1,
-    const PrecisionConfigProto* precision_config_proto = nullptr);
+    const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration as well as the dimension numbers.
@@ -1666,17 +1667,18 @@
                   absl::Span<const std::pair<int64, int64>> padding,
                   const ConvolutionDimensionNumbers& dimension_numbers,
                   int64 feature_group_count = 1,
-                  const PrecisionConfigProto* precision_config_proto = nullptr);
+                  const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues a convolution instruction onto the computation, with the caller
 // provided padding configuration, dilation factors and dimension numbers.
-XlaOp ConvGeneralDilated(
-    const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
-    absl::Span<const std::pair<int64, int64>> padding,
-    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
-    const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count = 1,
-    const PrecisionConfigProto* precision_config_proto = nullptr);
+XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
+                         absl::Span<const int64> window_strides,
+                         absl::Span<const std::pair<int64, int64>> padding,
+                         absl::Span<const int64> lhs_dilation,
+                         absl::Span<const int64> rhs_dilation,
+                         const ConvolutionDimensionNumbers& dimension_numbers,
+                         int64 feature_group_count = 1,
+                         const PrecisionConfig* precision_config = nullptr);
 
 // Enqueues an FFT instruction onto the computation, of the given type and
 // with the given FFT length.
@@ -2117,12 +2119,12 @@
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR0(NativeT value) {
-  return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
+  return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
-  return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
+  return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
 }
 
 template <typename NativeT>
@@ -2134,44 +2136,44 @@
 }
 
 inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
-  return ConstantLiteral(*LiteralUtil::CreateR1(values));
+  return ConstantLiteral(LiteralUtil::CreateR1(values));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR2(
     std::initializer_list<std::initializer_list<NativeT>> values) {
-  return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
+  return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
                                               const Layout& layout) {
   return ConstantLiteral(
-      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
-  return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
+  return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
   return ConstantLiteral(
-      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+      LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
-  return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+  return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 }
 
 template <typename NativeT>
 XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
   return ConstantLiteral(
-      *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
@@ -2194,12 +2196,12 @@
 
 template <typename NativeT>
 XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
-  return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
 }
 
 template <typename NativeT>
 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
-  return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+  return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
 }
 
 template <typename NativeT>
@@ -2212,13 +2214,13 @@
 
 inline XlaOp ConstantR1(XlaBuilder* builder,
                         const tensorflow::core::Bitmap& values) {
-  return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
 }
 
 template <typename NativeT>
 XlaOp ConstantR2(XlaBuilder* builder,
                  std::initializer_list<std::initializer_list<NativeT>> values) {
-  return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
 }
 
 template <typename NativeT>
@@ -2226,14 +2228,13 @@
                                   const Array<NativeT>& values,
                                   const Layout& layout) {
   return ConstantLiteral(
-      builder,
-      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
   return ConstantLiteral(builder,
-                         *LiteralUtil::CreateFromArray<NativeT>(values));
+                         LiteralUtil::CreateFromArray<NativeT>(values));
 }
 
 template <typename NativeT>
@@ -2241,15 +2242,14 @@
                                       const Array2D<NativeT>& values,
                                       const Layout& layout) {
   return ConstantLiteral(
-      builder,
-      *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                             const Array2D<NativeT>& values) {
   return ConstantLiteral(builder,
-                         *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
 }
 
 template <typename NativeT>
@@ -2258,7 +2258,7 @@
                                       const Layout& layout) {
   return ConstantLiteral(
       builder,
-      *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
 }
 
 template <typename NativeT>
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 3f7635b..5035f41 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -174,9 +174,9 @@
   return *this;
 }
 
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
-  auto literal = absl::make_unique<Literal>(shape);
-  literal->root_piece_->ForEachMutableSubpiece(
+Literal LiteralBase::CreateFromShape(const Shape& shape) {
+  Literal literal(shape);
+  literal.root_piece_->ForEachMutableSubpiece(
       [&](const ShapeIndex& index, Piece* piece) {
         if (ShapeUtil::IsArray(piece->subshape())) {
           memset(piece->untyped_data(), 0, piece->size_bytes());
@@ -278,8 +278,8 @@
   return Status::OK();
 }
 
-/* static */ StatusOr<std::unique_ptr<Literal>>
-MutableLiteralBase::CreateFromProto(const LiteralProto& proto) {
+/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
+    const LiteralProto& proto) {
   if (!proto.has_shape()) {
     return InvalidArgument("LiteralProto has no shape");
   }
@@ -287,9 +287,9 @@
     return InvalidArgument("LiteralProto has no layout");
   }
 
-  auto literal = absl::make_unique<Literal>(proto.shape());
+  Literal literal(proto.shape());
 
-  TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+  TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
       [&](const ShapeIndex& index, Piece* piece) {
         const LiteralProto* proto_element = &proto;
         for (int64 i : index) {
@@ -556,38 +556,37 @@
   }
 }
 
-std::unique_ptr<Literal> LiteralBase::Relayout(
-    const Layout& new_layout, const ShapeIndex& shape_index) const {
+Literal LiteralBase::Relayout(const Layout& new_layout,
+                              const ShapeIndex& shape_index) const {
   // Create new shape with 'new_layout' set at the given shape index.
   Shape new_shape = shape();
   Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
   TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
   *subshape->mutable_layout() = new_layout;
-  auto result = absl::make_unique<Literal>(new_shape);
-  TF_CHECK_OK(result->CopyFrom(*this));
+  Literal result(new_shape);
+  TF_CHECK_OK(result.CopyFrom(*this));
   return result;
 }
 
-std::unique_ptr<Literal> LiteralBase::Relayout(
-    const Shape& shape_with_layout) const {
+Literal LiteralBase::Relayout(const Shape& shape_with_layout) const {
   CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
       << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
       << " not compatible with literal shape "
       << ShapeUtil::HumanString(shape());
-  std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+  Literal result = CreateFromShape(shape_with_layout);
   ShapeUtil::ForEachSubshape(
-      result->shape(),
+      result.shape(),
       [this, &result](const Shape& subshape, const ShapeIndex& index) {
         if (ShapeUtil::IsArray(subshape)) {
-          TF_CHECK_OK(result->CopyFrom(*this,
-                                       /*dest_shape_index=*/index,
-                                       /*src_shape_index=*/index));
+          TF_CHECK_OK(result.CopyFrom(*this,
+                                      /*dest_shape_index=*/index,
+                                      /*src_shape_index=*/index));
         }
       });
   return result;
 }
 
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+StatusOr<Literal> LiteralBase::Broadcast(
     const Shape& result_shape, absl::Span<const int64> dimensions) const {
   if (!ShapeUtil::IsArray(shape())) {
     return InvalidArgument("Broadcast only supports arrays.");
@@ -598,14 +597,14 @@
                  result_shape.dimensions(dimensions[i]));
   }
 
-  std::unique_ptr<Literal> result = absl::make_unique<Literal>(result_shape);
+  Literal result(result_shape);
 
   // scratch_source_index is temporary storage space for the computed index into
   // the input literal.  We put it here to avoid allocating an std::vector in
   // every iteration of ShapeUtil::ForEachIndex.
   std::vector<int64> scratch_source_index(shape().dimensions_size());
 
-  char* dest_data = static_cast<char*>(result->untyped_data());
+  char* dest_data = static_cast<char*>(result.untyped_data());
   const char* source_data = static_cast<const char*>(untyped_data());
   const int64 primitive_size =
       ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
@@ -627,37 +626,36 @@
   return std::move(result);
 }
 
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+StatusOr<Literal> LiteralBase::Reshape(
     absl::Span<const int64> dimensions) const {
   if (!ShapeUtil::IsArray(shape())) {
     return InvalidArgument("Reshape does not support tuples.");
   }
-  std::unique_ptr<Literal> output;
+  Literal output;
   if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
     output =
         Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
   } else {
-    output = CloneToUnique();
+    output = Clone();
   }
   // Because the layout is monotonic, we can simply reuse the same sequence of
   // values without changing their order.
-  *output->mutable_shape_do_not_use() =
+  *output.mutable_shape_do_not_use() =
       ShapeUtil::MakeShape(shape().element_type(), dimensions);
 
   int64 elements_before = ShapeUtil::ElementsIn(shape());
-  int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+  int64 elements_after = ShapeUtil::ElementsIn(output.shape());
   if (elements_before != elements_after) {
     return InvalidArgument(
         "Shapes before and after Literal::Reshape have different numbers "
         "of elements: %s vs %s.",
         ShapeUtil::HumanString(shape()),
-        ShapeUtil::HumanString(output->shape()));
+        ShapeUtil::HumanString(output.shape()));
   }
   return std::move(output);
 }
 
-std::unique_ptr<Literal> LiteralBase::Transpose(
-    absl::Span<const int64> permutation) const {
+Literal LiteralBase::Transpose(absl::Span<const int64> permutation) const {
   CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
   CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
       << "Given permutation is not a permutation of dimension numbers";
@@ -687,32 +685,31 @@
   for (auto index : LayoutUtil::MinorToMajor(shape())) {
     layout->add_minor_to_major(inverse_permutation[index]);
   }
-  auto new_literal = absl::make_unique<Literal>(permuted_shape);
-  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+  Literal new_literal(permuted_shape);
+  DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()),
             ShapeUtil::ByteSizeOf(shape()));
-  std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+  std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes());
   return new_literal;
 }
 
 template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::SliceInternal(
+Literal LiteralBase::SliceInternal(
     const Shape& result_shape, absl::Span<const int64> start_indices) const {
-  auto result_literal = absl::make_unique<Literal>(result_shape);
+  Literal result_literal(result_shape);
   DimensionVector new_indices(ShapeUtil::Rank(result_shape));
-  result_literal->EachCell<NativeT>(
+  result_literal.EachCell<NativeT>(
       [&](absl::Span<const int64> indices, NativeT /*value*/) {
         for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
           new_indices[i] = indices[i] + start_indices[i];
         }
         NativeT value = Get<NativeT>(new_indices);
-        result_literal->Set<NativeT>(indices, value);
+        result_literal.Set<NativeT>(indices, value);
       });
   return result_literal;
 }
 
-std::unique_ptr<Literal> LiteralBase::Slice(
-    absl::Span<const int64> start_indices,
-    absl::Span<const int64> limit_indices) const {
+Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
+                           absl::Span<const int64> limit_indices) const {
   CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
 
   DimensionVector result_dimensions;
@@ -750,12 +747,6 @@
   return result;
 }
 
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
-  auto result = absl::make_unique<Literal>(shape());
-  TF_CHECK_OK(result->CopyFrom(*this));
-  return result;
-}
-
 string LiteralBase::GetAsString(absl::Span<const int64> multi_index,
                                 const ShapeIndex& shape_index) const {
   const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
@@ -1191,14 +1182,14 @@
 
 namespace {
 template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
-    const LiteralBase& src_literal, const ConverterType& converter) {
+Literal ConvertBetweenNativeTypesWithConverter(const LiteralBase& src_literal,
+                                               const ConverterType& converter) {
   CHECK(ShapeUtil::IsArray(src_literal.shape()));
-  auto result_literal = absl::make_unique<Literal>(ShapeUtil::ChangeElementType(
+  Literal result_literal(ShapeUtil::ChangeElementType(
       src_literal.shape(),
       primitive_util::NativeToPrimitiveType<NativeDestT>()));
   auto src_data = src_literal.data<NativeSrcT>();
-  auto dest_data = result_literal->template data<NativeDestT>();
+  auto dest_data = result_literal.template data<NativeDestT>();
   int64 num_elements = src_literal.element_count();
 
   for (int64 i = 0; i < num_elements; ++i) {
@@ -1208,8 +1199,7 @@
 }
 
 template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
-    const LiteralBase& src_literal) {
+Literal ConvertBetweenNativeTypes(const LiteralBase& src_literal) {
   auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
       src_literal, converter);
@@ -1217,7 +1207,7 @@
 
 template <typename NativeSrcT, typename NativeDestT>
 typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
-                        std::unique_ptr<Literal>>::type
+                        Literal>::type
 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
   auto converter = [](NativeSrcT src) {
     return tensorflow::bit_cast<NativeDestT>(src);
@@ -1232,20 +1222,20 @@
 // identical sizes higher up.
 template <typename NativeSrcT, typename NativeDestT>
 typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
-                        std::unique_ptr<Literal>>::type
+                        Literal>::type
 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
   LOG(FATAL) << "Invalid bitcast between types of different sizes.";
 }
 
 template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+Literal ConvertToC64(const LiteralBase& src_literal) {
   CHECK(ShapeUtil::IsArray(src_literal.shape()));
-  auto result_literal = absl::make_unique<Literal>(
+  Literal result_literal(
       ShapeUtil::ChangeElementType(src_literal.shape(), C64));
   using NativeSrcT =
       typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
   absl::Span<const NativeSrcT> src_data = src_literal.data<NativeSrcT>();
-  absl::Span<complex64> dest_data = result_literal->data<complex64>();
+  absl::Span<complex64> dest_data = result_literal.data<complex64>();
   int64 num_elements = src_literal.element_count();
   for (int64 i = 0; i < num_elements; ++i) {
     dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
@@ -1254,8 +1244,7 @@
 }
 
 template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
-                                             bool bitcast) {
+Literal ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) {
   CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
   if (bitcast) {
     return BitcastBetweenNativeTypes<
@@ -1273,9 +1262,9 @@
 }
 
 template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
-    const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
-    bool bitcast) {
+StatusOr<Literal> ConvertIfDestTypeMatches(const LiteralBase& src_literal,
+                                           PrimitiveType primitive_dest_type,
+                                           bool bitcast) {
   switch (primitive_dest_type) {
 #define CONVERT_IF_TYPES_MATCH(type)                                    \
   case (type):                                                          \
@@ -1307,12 +1296,12 @@
                        PrimitiveType_Name(primitive_dest_type));
 }
 
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
-    const LiteralBase& literal, PrimitiveType primitive_dest_type,
-    bool bitcast) {
+StatusOr<Literal> ConvertSwitch(const LiteralBase& literal,
+                                PrimitiveType primitive_dest_type,
+                                bool bitcast) {
   TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
   if (literal.shape().element_type() == primitive_dest_type) {
-    return literal.CloneToUnique();
+    return literal.Clone();
   }
   switch (literal.shape().element_type()) {
 #define CONVERT_IF_DEST_TYPE_MATCHES(type)                                \
@@ -1342,12 +1331,12 @@
 
 }  // namespace
 
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+StatusOr<Literal> LiteralBase::Convert(
     PrimitiveType primitive_dest_type) const {
   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
 }
 
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+StatusOr<Literal> LiteralBase::BitcastConvert(
     PrimitiveType primitive_dest_type) const {
   if (primitive_util::BitWidth(shape().element_type()) !=
       primitive_util::BitWidth(primitive_dest_type)) {
@@ -1362,17 +1351,8 @@
   return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
 }
 
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
-    const Shape& dest_shape, bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
   if (!ShapeUtil::IsTuple(dest_shape)) {
-    if (round_f32_to_bf16 && shape().element_type() == F32 &&
-        dest_shape.element_type() == BF16) {
-      auto converter = [](float src) {
-        return tensorflow::bfloat16::round_to_bfloat16(src);
-      };
-      return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
-                                                                     converter);
-    }
     return Convert(dest_shape.element_type());
   }
   std::vector<Literal> elements;
@@ -1381,11 +1361,9 @@
     TF_ASSIGN_OR_RETURN(
         auto new_element,
         element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
-    elements.push_back(std::move(*new_element));
+    elements.push_back(std::move(new_element));
   }
-  auto converted = absl::make_unique<Literal>();
-  *converted = MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
-  return std::move(converted);
+  return MutableLiteralBase::MoveIntoTuple(absl::MakeSpan(elements));
 }
 
 /* static */ Literal MutableLiteralBase::MoveIntoTuple(
@@ -1782,6 +1760,10 @@
     case PRED:
       CopyToRepeatedField(proto->mutable_preds(), data<bool>());
       break;
+    case S8:
+      proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
+                     element_count());
+      break;
     case U8:
       proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
                      element_count());
@@ -1872,6 +1854,11 @@
     case PRED:
       TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
       break;
+    case S8: {
+      auto s8_data = data<int8>();
+      TF_RET_CHECK(proto.s8s().size() == s8_data.size());
+      std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
+    } break;
     case U8: {
       auto u8_data = data<uint8>();
       TF_RET_CHECK(proto.u8s().size() == u8_data.size());
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index b928cb6..1e0a2ad 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -217,31 +217,20 @@
 
   // Converts this literal to the given shape. Returns an error is the
   // conversion is not possible.
-  //
-  // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
-  // instead of truncation; otherwise, truncation is used.
-  //
-  // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
-  // the default behavior.
-  StatusOr<std::unique_ptr<Literal>> ConvertToShape(
-      const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+  StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
 
   // Converts this literal to another primitive type using a bitcast
   // conversion. The to and from primitive types must have the same bit
   // width. Returns an error if the conversion is not possible. This literal
   // must be array-shaped.
-  StatusOr<std::unique_ptr<Literal>> BitcastConvert(
-      PrimitiveType primitive_dest_type) const;
+  StatusOr<Literal> BitcastConvert(PrimitiveType primitive_dest_type) const;
 
   // Converts this literal to another primitive type. Returns an error if the
   // conversion is not possible. This literal must be array-shaped.
-  StatusOr<std::unique_ptr<Literal>> Convert(
-      PrimitiveType primitive_dest_type) const;
+  StatusOr<Literal> Convert(PrimitiveType primitive_dest_type) const;
 
-  // Clones the underlying buffers into a new Literal, or new
-  // std::unique_ptr<Literal>.
+  // Clones the underlying buffers into a new Literal.
   Literal Clone() const;
-  std::unique_ptr<Literal> CloneToUnique() const;
 
   // TODO(b/67651157): The methods below which perform computation on Literals
   // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
@@ -259,24 +248,23 @@
   // Note: this is useful when the client wants to ensure that a value placed in
   // the XLA allocation tracker has a particular layout; for efficiency
   // purposes or avoiding unimplemented operation/layout combinations.
-  std::unique_ptr<Literal> Relayout(const Layout& new_layout,
-                                    const ShapeIndex& shape_index = {}) const;
+  Literal Relayout(const Layout& new_layout,
+                   const ShapeIndex& shape_index = {}) const;
 
   // An overload of Relayout which changes the layout of the entire shape rather
   // than being limited to a single array within the shape.
-  std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+  Literal Relayout(const Shape& shape_with_layout) const;
 
   // Creates a new literal by reshaping this literal to have the given
   // dimensions. The total number of elements must not change; The
   // implementation currently only supports monotonic dim0-major layouts.
   // This literal must be an array.
-  StatusOr<std::unique_ptr<Literal>> Reshape(
-      absl::Span<const int64> dimensions) const;
+  StatusOr<Literal> Reshape(absl::Span<const int64> dimensions) const;
 
   // Creates a new literal by broadcasting this literal with `dimensions` to
   // yield a literal of shape `result_shape`.
-  StatusOr<std::unique_ptr<Literal>> Broadcast(
-      const Shape& result_shape, absl::Span<const int64> dimensions) const;
+  StatusOr<Literal> Broadcast(const Shape& result_shape,
+                              absl::Span<const int64> dimensions) const;
 
   // Creates a new literal by reordering the dimensions of this literal.
   // The given `permutation` must be a permutation of the dimension numbers
@@ -285,7 +273,7 @@
   // For example, a transpose call on a literal of shape [3 x 8 x 4] and
   // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
   // This literal must be an array.
-  std::unique_ptr<Literal> Transpose(absl::Span<const int64> permutation) const;
+  Literal Transpose(absl::Span<const int64> permutation) const;
 
   // Creates a sub-array from this literal by extracting the indices
   // [start_index, limit_index) of each dimension. The result literal has the
@@ -293,15 +281,15 @@
   // start_indices and limit_indices must be the rank of the literal, and the
   // indices follow the order of the dimensions.
   // This literal must be an array.
-  std::unique_ptr<Literal> Slice(absl::Span<const int64> start_indices,
-                                 absl::Span<const int64> limit_indices) const;
+  Literal Slice(absl::Span<const int64> start_indices,
+                absl::Span<const int64> limit_indices) const;
 
   // Creates a literal with a prepended dimension with bound "times"; e.g. a
   // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
   // literal replicated four times.
   // This literal must be an array.
   template <typename NativeT>
-  std::unique_ptr<Literal> Replicate(int64 times) const;
+  Literal Replicate(int64 times) const;
 
   // Creates a new Literal object with the shape specified as parameter.
   // The content of the literal values is the default value of the primitive
@@ -312,7 +300,7 @@
   // initialization, then reinitialization. Conside if a call to
   // absl::make_unique<Literal>(shape), followed by the call to
   // MutableLiteralBase::Populate can be used instead.
-  static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+  static Literal CreateFromShape(const Shape& shape);
 
  protected:
   // A data structure representing a subshape at a particular ShapeIndex within
@@ -539,8 +527,8 @@
 
  private:
   template <typename NativeT>
-  std::unique_ptr<Literal> SliceInternal(
-      const Shape& result_shape, absl::Span<const int64> start_indices) const;
+  Literal SliceInternal(const Shape& result_shape,
+                        absl::Span<const int64> start_indices) const;
 };
 
 // Abstract base class representing a mutable literal in XLA.
@@ -687,8 +675,7 @@
   static Literal MoveIntoTuple(absl::Span<Literal> elements);
 
   // Serialize from a proto.
-  static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
-      const LiteralProto& proto);
+  static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
 
  protected:
   // Returns the piece at the given ShapeIndex.
@@ -1137,15 +1124,14 @@
 }
 
 template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+Literal LiteralBase::Replicate(int64 times) const {
   DimensionVector bounds = {times};
   bounds.reserve(shape().dimensions_size() + 1);
   for (int64 bound : shape().dimensions()) {
     bounds.push_back(bound);
   }
-  auto literal = absl::make_unique<Literal>(
-      ShapeUtil::MakeShape(shape().element_type(), bounds));
-  int64 elements = ShapeUtil::ElementsIn(literal->shape());
+  Literal literal(ShapeUtil::MakeShape(shape().element_type(), bounds));
+  int64 elements = ShapeUtil::ElementsIn(literal.shape());
   if (elements == 0) {
     return literal;
   }
@@ -1157,7 +1143,7 @@
   bool done = false;
   while (!done) {
     const auto element = Get<NativeT>(input_indices);
-    literal->Set<NativeT>(output_indices, element);
+    literal.Set<NativeT>(output_indices, element);
 
     done = true;
     for (int n = 0; n < output_indices.size(); ++n) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 1a64594..7ad287c 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -92,48 +92,48 @@
   Layout layout_r3_dim0minor_;
   Layout layout_r4_dim0major_;
   Layout layout_r4_dim0minor_;
-  std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
-  std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+  Literal literal_r4_2x2x3x3_dim0major_;
+  Literal literal_r4_2x2x3x3_dim0minor_;
 };
 
 TEST_F(LiteralUtilTest, LiteralScalarToString) {
   auto true_lit = LiteralUtil::CreateR0<bool>(true);
-  EXPECT_EQ("true", true_lit->ToString());
+  EXPECT_EQ("true", true_lit.ToString());
 
   auto false_lit = LiteralUtil::CreateR0<bool>(false);
-  EXPECT_EQ("false", false_lit->ToString());
+  EXPECT_EQ("false", false_lit.ToString());
 
   auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
-  EXPECT_EQ("42", u32_lit->ToString());
+  EXPECT_EQ("42", u32_lit.ToString());
 
   auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
-  EXPECT_EQ("-999", s32_lit->ToString());
+  EXPECT_EQ("-999", s32_lit.ToString());
 
   auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
-  EXPECT_EQ("3.14", f32_lit->ToString());
+  EXPECT_EQ("3.14", f32_lit.ToString());
 
   auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
-  EXPECT_EQ("0.5", f16_lit->ToString());
+  EXPECT_EQ("0.5", f16_lit.ToString());
 
   auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
-  EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString());
+  EXPECT_EQ("(3.14, 2.78)", c64_lit.ToString());
 
   auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
-  EXPECT_EQ("0.5", bf16_lit->ToString());
+  EXPECT_EQ("0.5", bf16_lit.ToString());
 
   // 3.14 will be rounded to 3.14062 in bfloat16 format.
   auto bf16_lit_truncated =
       LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
-  ASSERT_EQ("3.14062", bf16_lit_truncated->ToString());
+  ASSERT_EQ("3.14062", bf16_lit_truncated.ToString());
 
   auto bf16_lit_truncated2 =
       LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
-  EXPECT_EQ("9", bf16_lit_truncated2->ToString());
+  EXPECT_EQ("9", bf16_lit_truncated2.ToString());
 }
 
 TEST_F(LiteralUtilTest, LiteralVectorToString) {
   auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
-  EXPECT_EQ("{101}", pred_vec->ToString());
+  EXPECT_EQ("{101}", pred_vec.ToString());
 }
 
 TEST_F(LiteralUtilTest, R2ToString) {
@@ -143,7 +143,7 @@
   { 3, 4 },
   { 5, 6 }
 })";
-  EXPECT_EQ(expected, literal->ToString());
+  EXPECT_EQ(expected, literal.ToString());
 }
 
 TEST_F(LiteralUtilTest, R3ToString) {
@@ -157,13 +157,13 @@
 { { 5 },
   { 6 } }
 })";
-  EXPECT_EQ(expected, literal->ToString());
+  EXPECT_EQ(expected, literal.ToString());
 }
 
 TEST_F(LiteralUtilTest, TupleToString) {
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
   const string expected = R"((f32[], f32[2,2]) (
 1,
 f32[2,2] {
@@ -171,7 +171,7 @@
   { 3, 4 }
 }
 ))";
-  EXPECT_EQ(expected, tuple->ToString());
+  EXPECT_EQ(expected, tuple.ToString());
 }
 
 TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
@@ -187,8 +187,8 @@
   // clang-format on
 
   auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
-  EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
-  string result = literal->ToString();
+  EXPECT_THAT(literal.shape().dimensions(), ElementsAre(2, 3, 2));
+  string result = literal.ToString();
   const string expected = R"(f32[2,3,2] {
 { { 1, 2 },
   { 3, 4 },
@@ -220,10 +220,10 @@
   };
   std::vector<int64> expected_values = {8, 9, 7, 10};
 
-  EXPECT_EQ(literal->sparse_indices()->data(),
+  EXPECT_EQ(literal.sparse_indices()->data(),
             absl::Span<const int64>(expected_indices.data(),
                                     expected_indices.num_elements()));
-  EXPECT_EQ(literal->data<int64>(), absl::Span<const int64>(expected_values));
+  EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
 }
 
 TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
@@ -234,8 +234,8 @@
     {2001, 2002},
   }, /*projection_p=*/1, /*projection_z=*/2);
   // clang-format on
-  EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2));
-  string result = literal->ToString();
+  EXPECT_THAT(literal.shape().dimensions(), ElementsAre(1, 2, 3, 2));
+  string result = literal.ToString();
   const string expected = R"(f32[1,2,3,2] {
   {  /*i0=0*/
     {  /*i1=0*/
@@ -254,9 +254,9 @@
 }
 
 TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
-  EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(),
+  EXPECT_THAT(literal_r4_2x2x3x3_dim0major_.shape().dimensions(),
               ElementsAre(2, 2, 3, 3));
-  string result = literal_r4_2x2x3x3_dim0major_->ToString();
+  string result = literal_r4_2x2x3x3_dim0major_.ToString();
   const string expected = R"(f32[2,2,3,3] {
   {  /*i0=0*/
     {  /*i1=0*/
@@ -294,7 +294,7 @@
   });
   // clang-format on
   std::vector<std::tuple<int64, int64, string>> seen;
-  literal->EachCellAsString(
+  literal.EachCellAsString(
       [&seen](absl::Span<const int64> indices, const string& value) {
         seen.emplace_back(indices[0], indices[1], value);
       });
@@ -310,14 +310,14 @@
   auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
   auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
 
-  EXPECT_EQ(*f32_42, *f32_42);
-  EXPECT_EQ(*f32_42, *f32_42_clone);
+  EXPECT_EQ(f32_42, f32_42);
+  EXPECT_EQ(f32_42, f32_42_clone);
 
   auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
-  EXPECT_NE(*f32_42, *f32_123);
+  EXPECT_NE(f32_42, f32_123);
 
   auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
-  EXPECT_NE(*f32_42, *f64_42);
+  EXPECT_NE(f32_42, f64_42);
 }
 
 TEST_F(LiteralUtilTest, NonScalarEquality) {
@@ -330,12 +330,12 @@
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   Literal nil(ShapeUtil::MakeNil());
 
-  EXPECT_EQ(*matrix, *matrix);
-  EXPECT_EQ(*matrix, *matrix_clone);
-  EXPECT_NE(*matrix, *matrix_different);
-  EXPECT_NE(*matrix, *vector_literal);
-  EXPECT_NE(*matrix, *scalar);
-  EXPECT_NE(*matrix, nil);
+  EXPECT_EQ(matrix, matrix);
+  EXPECT_EQ(matrix, matrix_clone);
+  EXPECT_NE(matrix, matrix_different);
+  EXPECT_NE(matrix, vector_literal);
+  EXPECT_NE(matrix, scalar);
+  EXPECT_NE(matrix, nil);
   EXPECT_EQ(nil, nil);
 }
 
@@ -344,57 +344,54 @@
   auto token1 = LiteralUtil::CreateToken();
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
 
-  EXPECT_EQ(*token0, *token1);
-  EXPECT_NE(*token0, *scalar);
+  EXPECT_EQ(token0, token1);
+  EXPECT_NE(token0, scalar);
 
-  EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
-            *LiteralUtil::MakeTuple({token0.get()}));
-  EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
-            *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
-  EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
-            *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+  EXPECT_EQ(LiteralUtil::MakeTuple({&token0}),
+            LiteralUtil::MakeTuple({&token0}));
+  EXPECT_EQ(LiteralUtil::MakeTuple({&token0, &scalar}),
+            LiteralUtil::MakeTuple({&token1, &scalar}));
+  EXPECT_NE(LiteralUtil::MakeTuple({&token0, &scalar}),
+            LiteralUtil::MakeTuple({&scalar, &token1}));
 }
 
 TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
   // Test equality with literals which have different layouts.
-  auto colmajor = absl::make_unique<Literal>(
-      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
-  colmajor->Set<float>({0, 0}, 1.0);
-  colmajor->Set<float>({0, 1}, 2.0);
-  colmajor->Set<float>({1, 0}, 3.0);
-  colmajor->Set<float>({1, 1}, 4.0);
+  Literal colmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}));
+  colmajor.Set<float>({0, 0}, 1.0);
+  colmajor.Set<float>({0, 1}, 2.0);
+  colmajor.Set<float>({1, 0}, 3.0);
+  colmajor.Set<float>({1, 1}, 4.0);
 
-  auto rowmajor = absl::make_unique<Literal>(
-      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
-  rowmajor->Set<float>({0, 0}, 1.0);
-  rowmajor->Set<float>({0, 1}, 2.0);
-  rowmajor->Set<float>({1, 0}, 3.0);
-  rowmajor->Set<float>({1, 1}, 4.0);
+  Literal rowmajor(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}));
+  rowmajor.Set<float>({0, 0}, 1.0);
+  rowmajor.Set<float>({0, 1}, 2.0);
+  rowmajor.Set<float>({1, 0}, 3.0);
+  rowmajor.Set<float>({1, 1}, 4.0);
 
-  EXPECT_EQ(*rowmajor, *colmajor);
+  EXPECT_EQ(rowmajor, colmajor);
 }
 
 TEST_F(LiteralUtilTest, TupleEquality) {
   // Test equality with tuples.
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+  auto tuple1 = LiteralUtil::MakeTuple({&scalar, &matrix});
 
   // Tuple with the same elements. One element is shared with the original
   // tuple, the other is a clone of the element in the original tuple.
   auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
-  auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
-  EXPECT_EQ(*tuple1, *tuple2);
+  auto tuple2 = LiteralUtil::MakeTuple({&scalar_clone, &matrix});
+  EXPECT_EQ(tuple1, tuple2);
 
   // Tuple with elements reversed.
-  auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
-  EXPECT_NE(*tuple1, *reversed_tuple);
+  auto reversed_tuple = LiteralUtil::MakeTuple({&matrix, &scalar});
+  EXPECT_NE(tuple1, reversed_tuple);
 
   // Tuple with different value.
   auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
-  auto different_tuple =
-      LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
-  EXPECT_NE(*tuple1, *different_tuple);
+  auto different_tuple = LiteralUtil::MakeTuple({&scalar_42, &matrix});
+  EXPECT_NE(tuple1, different_tuple);
 }
 
 TEST_F(LiteralUtilTest, C64Equality) {
@@ -405,162 +402,161 @@
   // tuple, the other is a clone of the element in the original tuple.
   auto vector_clone =
       LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
-  EXPECT_EQ(*vector, *vector_clone);
+  EXPECT_EQ(vector, vector_clone);
 
   auto vector_reversed =
       LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
-  EXPECT_NE(*vector, *vector_reversed);
+  EXPECT_NE(vector, vector_reversed);
 }
 
 TEST_F(LiteralUtilTest, IsAllTuple) {
   auto element1 = LiteralUtil::CreateR0<float>(0.0);
   auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
-  auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+  auto tuple = LiteralUtil::MakeTuple({&element1, &element1});
 
   // Tuples should always return false for IsAll.
-  EXPECT_FALSE(tuple->IsAll(0));
-  EXPECT_FALSE(tuple->IsAll(1));
+  EXPECT_FALSE(tuple.IsAll(0));
+  EXPECT_FALSE(tuple.IsAll(1));
 }
 
 // Verifies that CreateFromShape works for tuples.
 TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
   auto scalar = LiteralUtil::CreateR0<float>(0.0);
   auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
-  auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
 
-  auto x = Literal::CreateFromShape(tuple->shape());
-  EXPECT_EQ(*tuple, *x);
+  auto x = Literal::CreateFromShape(tuple.shape());
+  EXPECT_EQ(tuple, x);
 }
 
 TEST_F(LiteralUtilTest, IsAll) {
-  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
-  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
+  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false).IsAll(0));
+  EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true).IsAll(1));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(1));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAll(2));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(2));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true).IsAll(-1));
 
   // We shouldn't reinterpret int8_min as an unsigned type and then decide that
   // it is equal to 255.
   auto int8_min = std::numeric_limits<int8>::min();
-  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
+  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255).IsAll(int8_min));
 
-  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
-  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
+  EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0).IsAll(42));
+  EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001).IsAll(42));
 
-  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
-  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+  EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100}).IsAll(100));
+  EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001}).IsAll(100));
 
-  EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+  EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}).IsAll(8));
 
   half h8(8.0f);
   half h9(9.0f);
-  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+  EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}}).IsAll(8));
 
   bfloat16 b8(8.0f);
   bfloat16 b9(9.0f);
 
-  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
-  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}}).IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}}).IsAll(8));
 
   // 9.001 will be truncated to 9.0
   bfloat16 b91(9.001f);
   bfloat16 b90(9.00f);
-  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+  EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}}).IsAll(9.0));
 
   complex64 c8_9 = {8, 9};
-  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAll(8));
 
   auto uint64_max = std::numeric_limits<uint64>::max();
   EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
                    {{uint64_max, uint64_max}, {uint64_max, uint64_max}})
-                   ->IsAll(-1));
+                   .IsAll(-1));
 }
 
 TEST_F(LiteralUtilTest, IsAllFloat) {
   // IsAllFloat always returns false when the literal is not floating-point.
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllFloat(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllFloat(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllFloat(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllFloat(0));
 
-  EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
-  EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
-  EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
-  EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
+  EXPECT_TRUE(LiteralUtil::CreateR0<float>(0).IsAllFloat(0));
+  EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5).IsAllFloat(.5));
+  EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.5));
+  EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5).IsAllFloat(-.49));
   EXPECT_FALSE(
-      LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+      LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
   EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
-                  ->IsAllFloat(.5));
+                  .IsAllFloat(.5));
 
-  EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
-  EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
-  EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
-  EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
+  EXPECT_TRUE(LiteralUtil::CreateR0<double>(0).IsAllFloat(0));
+  EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5).IsAllFloat(.5));
+  EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.5));
+  EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5).IsAllFloat(-.49));
   EXPECT_FALSE(
-      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+      LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}}).IsAllFloat(0));
 }
 
 TEST_F(LiteralUtilTest, IsAllComplex) {
   // IsAllComplex always returns false when the literal is not complex.
-  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
-  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false).IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0).IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0).IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<int>(0).IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<float>(0).IsAllComplex(0));
+  EXPECT_FALSE(LiteralUtil::CreateR0<double>(0).IsAllComplex(0));
 
   complex64 c8_9 = {8, 9};
   complex64 c7_9 = {7, 9};
   EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
-                  ->IsAllComplex({8.0f, 9.0f}));
+                  .IsAllComplex({8.0f, 9.0f}));
   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
-                   ->IsAllComplex({8.0f, 9.0f}));
+                   .IsAllComplex({8.0f, 9.0f}));
   EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
-                   ->IsAllComplex({8.0f, 9.0f}));
+                   .IsAllComplex({8.0f, 9.0f}));
 }
 
 TEST_F(LiteralUtilTest, IsAllFirst) {
   // IsAllComplex always returns false when the literal is not complex.
-  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
-  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
-  EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
-  EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
-  EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
-  EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
-  EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
-  EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
-  EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true}).IsAllFirst());
+  EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false}).IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2}).IsAllFirst());
+  EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5}).IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2}).IsAllFirst());
+  EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5}).IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2}).IsAllFirst());
+  EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5}).IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2}).IsAllFirst());
 
   complex64 c8_9 = {8, 9};
   complex64 c7_9 = {7, 9};
-  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
-  EXPECT_FALSE(
-      LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+  EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}}).IsAllFirst());
+  EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}}).IsAllFirst());
 }
 
 TEST_F(LiteralUtilTest, IsZero) {
   auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
   auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
-  EXPECT_TRUE(scalar_zero->IsZero({}));
-  EXPECT_FALSE(scalar_one->IsZero({}));
+  EXPECT_TRUE(scalar_zero.IsZero({}));
+  EXPECT_FALSE(scalar_one.IsZero({}));
 
   auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
-  EXPECT_FALSE(array->IsZero({0, 1}));
-  EXPECT_TRUE(array->IsZero({0, 2}));
-  EXPECT_TRUE(array->IsZero({1, 1}));
-  EXPECT_FALSE(array->IsZero({1, 2}));
+  EXPECT_FALSE(array.IsZero({0, 1}));
+  EXPECT_TRUE(array.IsZero({0, 2}));
+  EXPECT_TRUE(array.IsZero({1, 1}));
+  EXPECT_FALSE(array.IsZero({1, 2}));
 
   auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
   auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
-  EXPECT_TRUE(complex_zero->IsZero({}));
-  EXPECT_FALSE(complex_nonzero->IsZero({}));
+  EXPECT_TRUE(complex_zero.IsZero({}));
+  EXPECT_FALSE(complex_nonzero.IsZero({}));
 }
 
 template <typename T>
@@ -576,19 +572,19 @@
   const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
   const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
 
-  auto data01 = data->Relayout(layout01);
-  EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
-  EXPECT_EQ(*data, *data01);
+  auto data01 = data.Relayout(layout01);
+  EXPECT_TRUE(LayoutUtil::Equal(data01.shape().layout(), layout01));
+  EXPECT_EQ(data, data01);
 
-  auto data10 = data->Relayout(layout10);
-  EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
-  EXPECT_EQ(*data, *data10);
+  auto data10 = data.Relayout(layout10);
+  EXPECT_TRUE(LayoutUtil::Equal(data10.shape().layout(), layout10));
+  EXPECT_EQ(data, data10);
 }
 
 TEST_F(LiteralUtilTest, ReshapeR0) {
   auto original = LiteralUtil::CreateR0<float>(1.7f);
-  auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
-  EXPECT_EQ(*original, *reshape);
+  auto reshape = original.Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
+  EXPECT_EQ(original, reshape);
 }
 
 TEST_F(LiteralUtilTest, ReshapeR4) {
@@ -606,9 +602,9 @@
     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
   }, layout_r3_dim0major_);
   // clang-format on
-  auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
 
-  EXPECT_EQ(*expected, *reshape);
+  EXPECT_EQ(expected, reshape);
 }
 
 TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
@@ -626,15 +622,15 @@
     {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
   }, layout_r3_dim0major_);
   // clang-format on
-  auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie();
+  auto reshape = original.Reshape({3, 4, 2}).ConsumeValueOrDie();
 
-  EXPECT_EQ(*expected, *reshape);
+  EXPECT_EQ(expected, reshape);
 }
 
 TEST_F(LiteralUtilTest, TransposeR0) {
   auto original = LiteralUtil::CreateR0<float>(1.7f);
-  auto reshape = original->Transpose(/*permutation=*/{});
-  EXPECT_EQ(*original, *reshape);
+  auto reshape = original.Transpose(/*permutation=*/{});
+  EXPECT_EQ(original, reshape);
 }
 
 TEST_F(LiteralUtilTest, TransposeR4) {
@@ -646,10 +642,10 @@
      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   }});
   // clang-format on
-  auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1});
+  auto reshape = original.Transpose(/*permutation=*/{2, 3, 0, 1});
 
-  reshape->EachCell<float>([&](absl::Span<const int64> indices, float value) {
-    EXPECT_EQ(value, original->Get<float>(
+  reshape.EachCell<float>([&](absl::Span<const int64> indices, float value) {
+    EXPECT_EQ(value, original.Get<float>(
                          {indices[2], indices[3], indices[0], indices[1]}));
   });
 }
@@ -658,35 +654,35 @@
   // Tests that using Relayout on an array is equivalent to creating it in the
   // target layout in the first place.
   auto dim0minor_relaid_to_dim0major =
-      literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_);
-  EXPECT_EQ(*literal_r4_2x2x3x3_dim0major_, *dim0minor_relaid_to_dim0major);
+      literal_r4_2x2x3x3_dim0minor_.Relayout(layout_r4_dim0major_);
+  EXPECT_EQ(literal_r4_2x2x3x3_dim0major_, dim0minor_relaid_to_dim0major);
 
   auto dim0major_relaid_to_dim0minor =
-      literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_);
-  EXPECT_EQ(*literal_r4_2x2x3x3_dim0minor_, *dim0major_relaid_to_dim0minor);
+      literal_r4_2x2x3x3_dim0major_.Relayout(layout_r4_dim0minor_);
+  EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
 }
 
 TEST_F(LiteralUtilTest, TestR2LinearLayout) {
   // Test expected memory layout of R2 dim0-minor (column-major) literal.
   auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
-  EXPECT_EQ(mat_dim0minor->element_count(), 6);
-  EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
+  EXPECT_EQ(mat_dim0minor.element_count(), 6);
+  EXPECT_THAT(mat_dim0minor.data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
 
   // Test expected memory layout when using Relayout to row major.
-  auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_);
-  EXPECT_THAT(relaid_mat_to_dim0major->data<int32>(),
+  auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
+  EXPECT_THAT(relaid_mat_to_dim0major.data<int32>(),
               ElementsAre(1, 2, 3, 4, 5, 6));
 
   // Test expected memory layout of R2 created with dim0-major (row-major).
   auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
       {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
-  EXPECT_EQ(mat_dim0major->element_count(), 6);
-  EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
+  EXPECT_EQ(mat_dim0major.element_count(), 6);
+  EXPECT_THAT(mat_dim0major.data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
 
   // Test expected memory layout when using Relayout to column major.
-  auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_);
-  EXPECT_THAT(relaid_mat_to_dim0minor->data<int32>(),
+  auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
+  EXPECT_THAT(relaid_mat_to_dim0minor.data<int32>(),
               ElementsAre(1, 4, 2, 5, 3, 6));
 }
 
@@ -707,77 +703,77 @@
   auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
       arr3d, layout_r3_dim0minor_);
 
-  EXPECT_EQ(lit_dim0minor->element_count(), 12);
+  EXPECT_EQ(lit_dim0minor.element_count(), 12);
   std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
-  EXPECT_THAT(lit_dim0minor->data<int32>(),
+  EXPECT_THAT(lit_dim0minor.data<int32>(),
               testing::ElementsAreArray(expected_dim0minor));
 
   // Test expected memory layout when using Relayout to row major.
-  auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_);
+  auto relaid_lit_to_dim0major = lit_dim0minor.Relayout(layout_r3_dim0major_);
   std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
-  EXPECT_THAT(relaid_lit_to_dim0major->data<int32>(),
+  EXPECT_THAT(relaid_lit_to_dim0major.data<int32>(),
               testing::ElementsAreArray(expected_dim0major));
 
   // Test expected memory layout of R3 created with dim0-major (row-major).
   auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
       arr3d, layout_r3_dim0major_);
-  EXPECT_EQ(lit_dim0major->element_count(), 12);
-  EXPECT_THAT(lit_dim0major->data<int32>(),
+  EXPECT_EQ(lit_dim0major.element_count(), 12);
+  EXPECT_THAT(lit_dim0major.data<int32>(),
               testing::ElementsAreArray(expected_dim0major));
 
   // Test expected memory layout when using Relayout to column major.
-  auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_);
-  EXPECT_THAT(relaid_lit_to_dim0minor->data<int32>(),
+  auto relaid_lit_to_dim0minor = lit_dim0major.Relayout(layout_r3_dim0minor_);
+  EXPECT_THAT(relaid_lit_to_dim0minor.data<int32>(),
               testing::ElementsAreArray(expected_dim0minor));
 }
 
 TEST_F(LiteralUtilTest, SliceR0S32) {
   auto input = LiteralUtil::CreateR0<int32>(1);
-  auto result = input->Slice({}, {});
-  EXPECT_EQ(*input, *result);
+  auto result = input.Slice({}, {});
+  EXPECT_EQ(input, result);
 }
 
 TEST_F(LiteralUtilTest, SliceR1F32) {
   auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
-  auto result = input->Slice({3}, {4});
+  auto result = input.Slice({3}, {4});
   auto expected = LiteralUtil::CreateR1<float>({4.0});
-  EXPECT_EQ(*expected, *result);
+  EXPECT_EQ(expected, result);
 }
 
 TEST_F(LiteralUtilTest, SliceR2U32) {
   auto input_3x4 = LiteralUtil::CreateR2<uint32>(
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
-  auto result = input_3x4->Slice({0, 2}, {2, 4});
+  auto result = input_3x4.Slice({0, 2}, {2, 4});
   auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
-  EXPECT_EQ(*expected, *result);
+  EXPECT_EQ(expected, result);
 }
 
 TEST_F(LiteralUtilTest, SliceR3U32Full) {
   auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
       {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
-  auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
-  EXPECT_EQ(*input_2x3x2, *result);
+  auto result = input_2x3x2.Slice({0, 0, 0}, {2, 3, 2});
+  EXPECT_EQ(input_2x3x2, result);
 }
 
 TEST_F(LiteralUtilTest, PopulateR1S64) {
   Literal output(ShapeUtil::MakeShape(S64, {1}));
   output.PopulateR1<int64>({77});
   auto expected = LiteralUtil::CreateR1<int64>({77});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR1U64) {
   Literal output(ShapeUtil::MakeShape(U64, {2}));
   output.PopulateR1<uint64>({{77, 88}});
   auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR1C64) {
   Literal output(ShapeUtil::MakeShape(C64, {1}));
   output.PopulateR1<complex64>({{77, 88}});
   auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateR2C64) {
@@ -785,7 +781,7 @@
   output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
   auto expected =
       LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
@@ -793,7 +789,7 @@
   bfloat16 h(0.25f);
   output.PopulateWithValue<bfloat16>(h);
   auto expected = LiteralUtil::CreateR0<bfloat16>(h);
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
@@ -801,7 +797,7 @@
   bfloat16 h(0.5f);
   output.PopulateWithValue<bfloat16>(h);
   auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
@@ -809,28 +805,28 @@
   bfloat16 h(2.0f);
   output.PopulateWithValue<bfloat16>(h);
   auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
   Literal output(ShapeUtil::MakeShape(F32, {}));
   output.PopulateWithValue<float>(2.5f);
   auto expected = LiteralUtil::CreateR0<float>(2.5f);
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
   Literal output(ShapeUtil::MakeShape(S64, {3}));
   output.PopulateWithValue<int64>(-7);
   auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
   Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
   output.PopulateWithValue<uint64>(42);
   auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
@@ -838,7 +834,7 @@
   output.PopulateWithValue<complex64>({4, 2});
   auto expected =
       LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
@@ -846,7 +842,7 @@
   half h(0.25f);
   output.PopulateWithValue<half>(h);
   auto expected = LiteralUtil::CreateR0<half>(h);
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
@@ -854,7 +850,7 @@
   half h(0.5f);
   output.PopulateWithValue<half>(h);
   auto expected = LiteralUtil::CreateR1<half>({h, h, h});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
@@ -862,18 +858,18 @@
   half h(2.0f);
   output.PopulateWithValue<half>(h);
   auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
-  EXPECT_EQ(output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, ReplicateR2U32) {
   auto input = LiteralUtil::CreateR2<uint32>(
       {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
-  auto output = input->Replicate<uint32>(3);
+  auto output = input.Replicate<uint32>(3);
   auto expected = LiteralUtil::CreateR3<uint32>(
       {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
        {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
-  EXPECT_EQ(*output, *expected);
+  EXPECT_EQ(output, expected);
 }
 
 TEST_F(LiteralUtilTest, CopySliceFrom) {
@@ -889,17 +885,17 @@
     const int64 step[] = {1, 1, 1, 1};
     uint32 seqnr = 0;
     auto init_proc = [&](absl::Span<const int64> indexes) {
-      source->Set(indexes, ++seqnr);
+      source.Set(indexes, ++seqnr);
       return true;
     };
-    ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step,
+    ShapeUtil::ForEachIndex(source.shape(), zero_base, dimensions, step,
                             init_proc);
 
     auto blank = Literal::CreateFromShape(shape);
     const int64 src_base[] = {3, 1, 5, 7};
     const int64 dest_base[] = {6, 4, 12, 2};
     const int64 copy_size[] = {7, 8, 11, 9};
-    TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size));
+    TF_EXPECT_OK(blank.CopySliceFrom(source, src_base, dest_base, copy_size));
 
     std::vector<int64> source_indexes(TF_ARRAYSIZE(dimensions), 0);
     std::vector<int64> blank_indexes(TF_ARRAYSIZE(dimensions), 0);
@@ -911,12 +907,12 @@
       std::copy(indexes.begin(), indexes.end(), blank_indexes.begin());
       std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base,
                      blank_indexes.begin(), std::plus<int64>());
-      auto bval = blank->Get<uint32>(blank_indexes);
-      matched = (bval != 0 && bval == source->Get<uint32>(source_indexes));
+      auto bval = blank.Get<uint32>(blank_indexes);
+      matched = (bval != 0 && bval == source.Get<uint32>(source_indexes));
       return matched;
     };
 
-    ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step,
+    ShapeUtil::ForEachIndex(source.shape(), zero_base, copy_size, step,
                             check_proc);
     EXPECT_TRUE(matched);
   }
@@ -925,14 +921,14 @@
 TEST_F(LiteralUtilTest, CopyFromScalars) {
   auto zero = LiteralUtil::CreateR0<uint32>(0);
   auto nine = LiteralUtil::CreateR0<uint32>(9);
-  TF_EXPECT_OK(zero->CopyFrom(*nine));
-  EXPECT_EQ(*zero, *nine);
+  TF_EXPECT_OK(zero.CopyFrom(nine));
+  EXPECT_EQ(zero, nine);
 
   auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
-  TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
-  EXPECT_EQ(zero->Get<uint32>({}), 17);
-  TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
-  EXPECT_EQ(vect->Get<uint32>({4}), 17);
+  TF_EXPECT_OK(zero.CopySliceFrom(vect, {5}, {}, {}));
+  EXPECT_EQ(zero.Get<uint32>({}), 17);
+  TF_EXPECT_OK(vect.CopySliceFrom(zero, {}, {4}, {}));
+  EXPECT_EQ(vect.Get<uint32>({4}), 17);
 }
 
 TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
@@ -945,17 +941,17 @@
     const auto empty = Literal::CreateFromShape(empty_r1_shape);
     auto nine = LiteralUtil::CreateR1<float>({9});
 
-    TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
-    EXPECT_EQ(*nine, *const_nine);
+    TF_EXPECT_OK(nine.CopySliceFrom(empty, {0}, {0}, {0}));
+    EXPECT_EQ(nine, const_nine);
   }
 
   {
     // Copy 0 element to destination with zero elements.
-    const auto empty = Literal::CreateFromShape(empty_r1_shape);
+    auto empty = Literal::CreateFromShape(empty_r1_shape);
     auto nine = LiteralUtil::CreateR1<float>({9});
 
-    TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
-    EXPECT_EQ(*empty, *const_empty);
+    TF_EXPECT_OK(empty.CopySliceFrom(nine, {0}, {0}, {0}));
+    EXPECT_EQ(empty, const_empty);
   }
 }
 
@@ -969,74 +965,75 @@
 TEST_F(LiteralUtilTest, CopyFromArrays) {
   auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
   auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
-  EXPECT_NE(*scalar_42, *scalar_123);
-  TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
-                                   /*src_shape_index=*/{}));
-  EXPECT_EQ(*scalar_42, *scalar_123);
-  EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
+  EXPECT_NE(scalar_42, scalar_123);
+  TF_ASSERT_OK(scalar_42.CopyFrom(scalar_123, /*dest_shape_index=*/{},
+                                  /*src_shape_index=*/{}));
+  EXPECT_EQ(scalar_42, scalar_123);
+  EXPECT_EQ(scalar_42.Get<float>({}), 123.0f);
 
   auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
-  EXPECT_NE(*matrix_1234, *matrix_5678);
-  EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
-  TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
-                                     /*src_shape_index=*/{}));
-  EXPECT_EQ(*matrix_1234, *matrix_5678);
-  EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 5.0f);
+  EXPECT_NE(matrix_1234, matrix_5678);
+  EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 1.0f);
+  TF_ASSERT_OK(matrix_1234.CopyFrom(matrix_5678, /*dest_shape_index=*/{},
+                                    /*src_shape_index=*/{}));
+  EXPECT_EQ(matrix_1234, matrix_5678);
+  EXPECT_EQ(matrix_1234.Get<float>({0, 0}), 5.0f);
 }
 
 TEST_F(LiteralUtilTest, CopyFromTuples) {
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   Literal nil_literal(ShapeUtil::MakeNil());
-  auto nested_tuple = LiteralUtil::MakeTuple(
-      {matrix.get(),
-       LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR0<int32>(42).get(),
-            LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
-           .get()});
+  Literal inner_elements[] = {LiteralUtil::CreateR0<int32>(42),
+                              LiteralUtil::CreateR1<double>({23.0, 44.0})};
+  Literal inner_tuple = LiteralUtil::MakeTuple(
+      {&inner_elements[0], &inner_elements[1], &nil_literal});
+  Literal nested_tuple = LiteralUtil::MakeTuple({&matrix, &inner_tuple});
   // Create a tuple the same shape as the inner tuple of nested_tuple but with
   // different values..
-  auto tuple = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<int32>(-5).get(),
-       LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
+  Literal int32_minus5 = LiteralUtil::CreateR0<int32>(-5);
+  Literal double_2_4 = LiteralUtil::CreateR1<double>({2.0, 4.0});
+  Literal tuple =
+      LiteralUtil::MakeTuple({&int32_minus5, &double_2_4, &nil_literal});
 
-  EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
-  EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
-  EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 23.0);
-  EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 44.0);
+  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
+  EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), 42);
+  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 23.0);
+  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 44.0);
 
   // Overwrite the inner tuple element of nested_tuple with the contents of
   // 'tuple'.
-  TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
-                                      /*src_shape_index=*/{}));
+  TF_ASSERT_OK(nested_tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+                                     /*src_shape_index=*/{}));
 
   // The matrix element should be unchanged.
-  EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
+  EXPECT_EQ(matrix, LiteralSlice(nested_tuple, {0}));
 
   // The tuple element should have been copied from 'tuple'.
-  EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), -5);
-  EXPECT_EQ(nested_tuple->Get<double>({0}, {1, 1}), 2.0);
-  EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
+  EXPECT_EQ(nested_tuple.Get<int32>({}, {1, 0}), -5);
+  EXPECT_EQ(nested_tuple.Get<double>({0}, {1, 1}), 2.0);
+  EXPECT_EQ(nested_tuple.Get<double>({1}, {1, 1}), 4.0);
 }
 TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
-  auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
-                                       LiteralUtil::CreateR0<int32>(4).get()});
+  Literal elements[] = {LiteralUtil::CreateR0<int32>(-2),
+                        LiteralUtil::CreateR0<int32>(4)};
+  Literal tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
 
-  EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
-  EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
+  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+  EXPECT_EQ(tuple.Get<int32>({}, {1}), 4);
 
   // Copy from one element to the other.
-  TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1},
-                               /*src_shape_index=*/{0}));
+  TF_ASSERT_OK(tuple.CopyFrom(tuple, /*dest_shape_index=*/{1},
+                              /*src_shape_index=*/{0}));
 
-  EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
-  EXPECT_EQ(tuple->Get<int32>({}, {1}), -2);
+  EXPECT_EQ(tuple.Get<int32>({}, {0}), -2);
+  EXPECT_EQ(tuple.Get<int32>({}, {1}), -2);
 }
 
 TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
-  Status status = matrix->CopyFrom(*vector);
+  Status status = matrix.CopyFrom(vector);
   ASSERT_FALSE(status.ok());
   EXPECT_THAT(status.error_message(),
               HasSubstr("Destination subshape incompatible"));
@@ -1046,9 +1043,8 @@
   // Verify that the internal data views are consistent and that they
   // are in little endian format
   // TODO - modify if we make the data format machine endianess dependent
-  auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
-  Literal* l1 = m1.get();
-  const char* d1 = reinterpret_cast<const char*>(l1->data<half>().data());
+  Literal m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2}));
+  const char* d1 = reinterpret_cast<const char*>(m1.data<half>().data());
   EXPECT_EQ(d1[0], 0);
   EXPECT_EQ(d1[1], 0);
   EXPECT_EQ(d1[2], 0);
@@ -1061,8 +1057,7 @@
   half h1(1.0f);
   half h2(2.0f);
   auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
-  Literal* l2 = m2.get();
-  const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
+  const char* d2 = reinterpret_cast<const char*>(m2.data<half>().data());
   EXPECT_EQ(d2[0], 0);
   EXPECT_EQ(d2[1], 0x3C);
   EXPECT_EQ(d2[2], 0);
@@ -1091,25 +1086,25 @@
     Shape shape = ShapeUtil::MakeShapeWithLayout(
         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
         data.layout);
-    auto literal = absl::make_unique<Literal>(shape);
+    Literal literal(shape);
     auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
       // Offsets from linear index just to avoid R0 literals to be initialized
       // with zero.
-      return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                            indexes) +
              17;
     };
-    TF_EXPECT_OK(literal->Populate<uint32>(generator));
+    TF_EXPECT_OK(literal.Populate<uint32>(generator));
 
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
     bool matched = true;
     auto check_function = [&](absl::Span<const int64> indexes) {
-      auto value = literal->Get<uint32>(indexes);
+      auto value = literal.Get<uint32>(indexes);
       matched = matched && (value == generator(indexes));
       return matched;
     };
-    ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+    ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
                             check_function);
     EXPECT_TRUE(matched);
   }
@@ -1133,25 +1128,25 @@
     Shape shape = ShapeUtil::MakeShapeWithLayout(
         primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
         data.layout);
-    auto literal = absl::make_unique<Literal>(shape);
+    Literal literal(shape);
     auto generator = [&](absl::Span<const int64> indexes) -> uint32 {
       // Offsets from linear index just to avoid R0 literals to be initialized
       // with zero.
-      return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+      return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
                                                            indexes) +
              17;
     };
-    TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+    TF_EXPECT_OK(literal.PopulateParallel<uint32>(generator));
 
     std::vector<int64> zero_base(data.dimensions.size(), 0);
     std::vector<int64> step(data.dimensions.size(), 1);
     bool matched = true;
     auto check_function = [&](absl::Span<const int64> indexes) {
-      auto value = literal->Get<uint32>(indexes);
+      auto value = literal.Get<uint32>(indexes);
       matched = matched && (value == generator(indexes));
       return matched;
     };
-    ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+    ShapeUtil::ForEachIndex(literal.shape(), zero_base, data.dimensions, step,
                             check_function);
     EXPECT_TRUE(matched);
   }
@@ -1170,10 +1165,9 @@
      {{26, 27, 28, 29}, {30, 31, 32, 33}},
   }}, layout_r4_dim0major_);
   // clang-format on
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
-                          original->Convert(U32));
+  TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.Convert(U32));
 
-  EXPECT_EQ(*expected, *converted);
+  EXPECT_EQ(expected, converted);
 }
 
 TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
@@ -1245,69 +1239,65 @@
     {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
   }}, layout_r4_dim0major_);
   // clang-format on
-  std::unique_ptr<Literal> conv;
+  Literal conv;
 
-  conv = s8->Convert(U32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *u32);
+  conv = s8.Convert(U32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, u32);
 
-  conv = s8->Convert(S32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *s32);
+  conv = s8.Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, s32);
 
-  conv = s8->Convert(U64).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *u64);
+  conv = s8.Convert(U64).ConsumeValueOrDie();
+  EXPECT_EQ(conv, u64);
 
-  conv = s8->Convert(S64).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *s64);
+  conv = s8.Convert(S64).ConsumeValueOrDie();
+  EXPECT_EQ(conv, s64);
 
-  conv = s8->Convert(PRED).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *pred);
+  conv = s8.Convert(PRED).ConsumeValueOrDie();
+  EXPECT_EQ(conv, pred);
 
-  conv = bf16->Convert(S32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *s32);
+  conv = bf16.Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, s32);
 
-  conv = bf16->Convert(F32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f32);
+  conv = bf16.Convert(F32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f32);
 
-  conv = pred->Convert(S32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *int32_pred);
+  conv = pred.Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, int32_pred);
 
-  conv = f32->Convert(S32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *s32);
+  conv = f32.Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, s32);
 
-  conv = f64->Convert(S32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *s32);
+  conv = f64.Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, s32);
 
-  conv = s32->Convert(F32).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f32);
+  conv = s32.Convert(F32).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f32);
 
-  conv = f32->Convert(F16).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f16);
+  conv = f32.Convert(F16).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f16);
 
-  conv = f64->Convert(F16).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f16);
+  conv = f64.Convert(F16).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f16);
 
-  conv = s32->Convert(F16).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f16);
+  conv = s32.Convert(F16).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f16);
 
-  conv = u32->Convert(F16).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *f16);
+  conv = u32.Convert(F16).ConsumeValueOrDie();
+  EXPECT_EQ(conv, f16);
 
-  conv = s32->Convert(C64).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *c64);
+  conv = s32.Convert(C64).ConsumeValueOrDie();
+  EXPECT_EQ(conv, c64);
 
-  conv = f16->Convert(C64).ConsumeValueOrDie();
-  EXPECT_EQ(*conv, *c64);
+  conv = f16.Convert(C64).ConsumeValueOrDie();
+  EXPECT_EQ(conv, c64);
 
-  EXPECT_EQ(s32->Convert(TUPLE).status().code(),
+  EXPECT_EQ(s32.Convert(TUPLE).status().code(),
             tensorflow::error::UNIMPLEMENTED);
-  EXPECT_EQ(s32->Convert(S16).status().code(),
-            tensorflow::error::UNIMPLEMENTED);
-  EXPECT_EQ(s32->Convert(U16).status().code(),
-            tensorflow::error::UNIMPLEMENTED);
-  EXPECT_EQ(c64->Convert(F32).status().code(),
-            tensorflow::error::UNIMPLEMENTED);
-  EXPECT_EQ(c64->Convert(S32).status().code(),
-            tensorflow::error::UNIMPLEMENTED);
+  EXPECT_EQ(s32.Convert(S16).status().code(), tensorflow::error::UNIMPLEMENTED);
+  EXPECT_EQ(s32.Convert(U16).status().code(), tensorflow::error::UNIMPLEMENTED);
+  EXPECT_EQ(c64.Convert(F32).status().code(), tensorflow::error::UNIMPLEMENTED);
+  EXPECT_EQ(c64.Convert(S32).status().code(), tensorflow::error::UNIMPLEMENTED);
 }
 
 TEST_F(LiteralUtilTest, BitcastConvert) {
@@ -1317,13 +1307,12 @@
        tensorflow::bit_cast<uint32>(100.f), 0xbeef});
   auto expected = LiteralUtil::CreateR1<float>(
       {2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
-                          original->BitcastConvert(F32));
+  TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
 }
 
 TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
   auto literal = LiteralUtil::CreateR0<uint32>(1234);
-  Status status = literal->BitcastConvert(F64).status();
+  Status status = literal.BitcastConvert(F64).status();
   EXPECT_NE(Status::OK(), status);
   EXPECT_TRUE(
       absl::StrContains(status.error_message(), "bit widths are different"));
@@ -1341,11 +1330,10 @@
       p.add_preds((i % 2) == (len % 2));
     }
 
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
-                            Literal::CreateFromProto(p));
-    ASSERT_EQ(len, literal->data<bool>().size());
+    TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+    ASSERT_EQ(len, literal.data<bool>().size());
     int i = 0;
-    for (bool value : literal->data<bool>()) {
+    for (bool value : literal.data<bool>()) {
       EXPECT_EQ((i % 2) == (len % 2), value);
       ++i;
     }
@@ -1358,11 +1346,10 @@
   half h2(2.0f);
 
   auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
-  Literal* l = m.get();
-  EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
-  EXPECT_EQ(4, l->data<half>().size());
+  EXPECT_EQ(4, ShapeUtil::ElementsIn(m.shape()));
+  EXPECT_EQ(4, m.data<half>().size());
 
-  LiteralProto p = l->ToProto();
+  LiteralProto p = m.ToProto();
   EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
   EXPECT_EQ(8, p.f16s().size());
   const char* d = p.f16s().data();
@@ -1389,9 +1376,8 @@
   LayoutUtil::SetToDefaultLayout(p.mutable_shape());
   p.clear_f16s();
   p.set_f16s(half_vals, 8);
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> literal,
-                          Literal::CreateFromProto(p));
-  auto r = literal->data<half>();
+  TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
+  auto r = literal.data<half>();
   ASSERT_EQ(4, r.size());
   EXPECT_EQ(h1, r[0]);
   EXPECT_EQ(h2, r[1]);
@@ -1402,43 +1388,41 @@
 TEST_F(LiteralUtilTest, LiteralSliceTest) {
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
-  auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
   Literal nil(ShapeUtil::MakeNil());
 
-  EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
-  EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix);
-  EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple);
-  EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple);
+  EXPECT_EQ(LiteralSlice(scalar, {}), scalar);
+  EXPECT_EQ(LiteralSlice(matrix, {}), matrix);
+  EXPECT_EQ(LiteralSlice(tuple, {}), tuple);
+  EXPECT_EQ(LiteralSlice(nested_tuple, {}), nested_tuple);
   EXPECT_EQ(LiteralSlice(nil, {}), nil);
 
-  EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar);
-  EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix);
+  EXPECT_EQ(LiteralSlice(tuple, {0}), scalar);
+  EXPECT_EQ(LiteralSlice(tuple, {1}), matrix);
 
-  EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple);
-  EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar);
-  EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix);
-  EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar);
+  EXPECT_EQ(LiteralSlice(nested_tuple, {0}), tuple);
+  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 0}), scalar);
+  EXPECT_EQ(LiteralSlice(nested_tuple, {0, 1}), matrix);
+  EXPECT_EQ(LiteralSlice(nested_tuple, {1}), scalar);
 }
 
 TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
-  auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
   // Verify that changing the underlying data beneath the view changes the
   // data of the view itself.
-  const auto nested_tuple_view = LiteralSlice(*nested_tuple);
-  EXPECT_EQ(
-      nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
-      1.0f);
+  const auto nested_tuple_view = LiteralSlice(nested_tuple);
+  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+            1.0f);
   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
                                          /*shape_index=*/{0, 0}),
             1.0f);
-  nested_tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
-  EXPECT_EQ(
-      nested_tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
-      555.0f);
+  nested_tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f);
+  EXPECT_EQ(nested_tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0, 0}),
+            555.0f);
   EXPECT_EQ(nested_tuple_view.Get<float>(/*multi_index=*/{},
                                          /*shape_index=*/{0, 0}),
             555.0f);
@@ -1447,14 +1431,14 @@
 TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
   auto scalar = LiteralUtil::CreateR0<float>(1.0);
   auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
-  auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+  auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix});
+  auto nested_tuple = LiteralUtil::MakeTuple({&tuple, &scalar});
 
-  const auto nested_tuple_view = LiteralSlice(*nested_tuple);
+  const auto nested_tuple_view = LiteralSlice(nested_tuple);
   const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
   const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
   EXPECT_EQ(matrix_view,
-            *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+            LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
 }
 
 TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
@@ -1497,9 +1481,8 @@
 }
 
 TEST_F(LiteralUtilTest, LiteralMove) {
-  std::unique_ptr<Literal> matrix =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  Literal literal(std::move(*matrix));
+  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal literal(std::move(matrix));
 
   EXPECT_TRUE(
       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1511,17 +1494,21 @@
 
 TEST_F(LiteralUtilTest, DecomposeTuple) {
   Literal nil_literal(ShapeUtil::MakeNil());
-  auto nested_tuple = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
-       LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR0<int32>(42).get(),
-            LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
-           .get(),
-       &nil_literal});
+  Literal inner_elements[] = {
+      LiteralUtil::CreateR0<int32>(42),
+      LiteralUtil::CreateR1<double>({23.0, 44.0}),
+  };
+  Literal tuple_elements[] = {
+      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}),
+      LiteralUtil::MakeTuple(
+          {&inner_elements[0], &inner_elements[1], &nil_literal}),
+  };
+  Literal nested_tuple = LiteralUtil::MakeTuple(
+      {&tuple_elements[0], &tuple_elements[1], &nil_literal});
 
-  EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape()));
-  std::vector<Literal> elements = nested_tuple->DecomposeTuple();
-  EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape()));
+  EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple.shape()));
+  std::vector<Literal> elements = nested_tuple.DecomposeTuple();
+  EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple.shape()));
 
   ASSERT_EQ(elements.size(), 3);
 
@@ -1552,13 +1539,13 @@
 
 TEST_F(LiteralUtilTest, MoveIntoTuple) {
   std::vector<Literal> elements;
-  elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
-  elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
-  elements.push_back(std::move(*LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<int32>(42).get(),
-       LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
-
-                                   ));
+  elements.push_back(LiteralUtil::CreateR0<float>(1.0));
+  elements.push_back(LiteralUtil::CreateR1<int32>({4, 8}));
+  std::vector<Literal> inner_elements;
+  inner_elements.push_back(LiteralUtil::CreateR0<int32>(42));
+  inner_elements.push_back(LiteralUtil::CreateR1<double>({23.0, 44.0}));
+  elements.push_back(
+      LiteralUtil::MakeTuple({&inner_elements[0], &inner_elements[1]}));
 
   Literal literal = Literal::MoveIntoTuple(absl::MakeSpan(elements));
   ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1586,9 +1573,8 @@
   Literal literal;
   EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
 
-  std::unique_ptr<Literal> matrix =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  literal = std::move(*matrix);
+  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  literal = std::move(matrix);
 
   EXPECT_TRUE(
       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape()));
@@ -1599,9 +1585,8 @@
 }
 
 TEST_F(LiteralUtilTest, LiteralSliceCopy) {
-  std::unique_ptr<Literal> matrix =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
-  const auto matrix_view = LiteralSlice(*matrix);
+  Literal matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  const auto matrix_view = LiteralSlice(matrix);
   LiteralSlice matrix_view_copy(matrix_view);
 
   EXPECT_EQ(matrix_view_copy.Get<float>({0, 0}), 1.0);
@@ -1611,45 +1596,43 @@
 }
 
 TEST_F(LiteralUtilTest, GetSetTuple) {
-  auto tuple = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(42.0).get(),
-       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
-  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
-  tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
-  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
+  Literal elements[] = {
+      LiteralUtil::CreateR0<float>(42.0),
+      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+  };
+  auto tuple = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
+  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
+  tuple.Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
+  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
 
-  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
-            3.0);
-  tuple->Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
-  EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
+  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), 3.0);
+  tuple.Set<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0);
+  EXPECT_EQ(tuple.Get<float>(/*multi_index=*/{1, 0}, /*shape_index=*/{1}),
             -4.0);
 }
 
 TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
   // Literals constructed using CreateFromShape should be zero initialized.
-  std::unique_ptr<Literal> scalar_f32 =
-      Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
-  EXPECT_EQ(scalar_f32->Get<float>({}), 0.0);
-  EXPECT_TRUE(scalar_f32->IsAll(0));
+  Literal scalar_f32 = Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {}));
+  EXPECT_EQ(scalar_f32.Get<float>({}), 0.0);
+  EXPECT_TRUE(scalar_f32.IsAll(0));
 
-  std::unique_ptr<Literal> vector_s32 =
-      Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
-  EXPECT_EQ(vector_s32->Get<int32>({0}), 0);
-  EXPECT_EQ(vector_s32->Get<int32>({1}), 0);
-  EXPECT_EQ(vector_s32->Get<int32>({2}), 0);
-  EXPECT_TRUE(vector_s32->IsAll(0));
+  Literal vector_s32 = Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3}));
+  EXPECT_EQ(vector_s32.Get<int32>({0}), 0);
+  EXPECT_EQ(vector_s32.Get<int32>({1}), 0);
+  EXPECT_EQ(vector_s32.Get<int32>({2}), 0);
+  EXPECT_TRUE(vector_s32.IsAll(0));
 
-  std::unique_ptr<Literal> tuple =
-      Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
-          {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
-           ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
+  Literal tuple = Literal::CreateFromShape(ShapeUtil::MakeTupleShape(
+      {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}),
+       ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})}));
 
-  EXPECT_EQ(tuple->Get<double>({}, {0}), 0.0);
-  EXPECT_EQ(tuple->Get<bool>({0}, {1}), false);
-  EXPECT_EQ(tuple->Get<bool>({1}, {1}), false);
-  EXPECT_EQ(tuple->Get<uint64>({0, 0}, {2}), 0);
-  EXPECT_EQ(tuple->Get<uint64>({1, 0}, {2}), 0);
-  EXPECT_EQ(tuple->Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
+  EXPECT_EQ(tuple.Get<double>({}, {0}), 0.0);
+  EXPECT_EQ(tuple.Get<bool>({0}, {1}), false);
+  EXPECT_EQ(tuple.Get<bool>({1}, {1}), false);
+  EXPECT_EQ(tuple.Get<uint64>({0, 0}, {2}), 0);
+  EXPECT_EQ(tuple.Get<uint64>({1, 0}, {2}), 0);
+  EXPECT_EQ(tuple.Get<complex64>({}, {3}), complex64(0.0f, 0.0f));
 }
 
 TEST_F(LiteralUtilTest, ProtoRoundTrip) {
@@ -1657,6 +1640,7 @@
   auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
   auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
   auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+  auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
   auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
   auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
       {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
@@ -1665,25 +1649,27 @@
   auto matrix_pred =
       LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
   auto tuple = LiteralUtil::MakeTuple(
-      {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
+      {&one_f32, &vector_half, &matrix_pred, &matrix_pred});
   Literal nil_literal(ShapeUtil::MakeNil());
-  auto nested_tuple = LiteralUtil::MakeTuple(
-      {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
+  auto nested_tuple =
+      LiteralUtil::MakeTuple({&tuple, &vector_bfloat16, &tuple, &nil_literal});
 
   auto to_from_proto = [](const Literal& literal) -> Literal {
-    return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie());
+    return Literal::CreateFromProto(literal.ToProto()).ValueOrDie();
   };
 
-  EXPECT_EQ(*one_f32, to_from_proto(*one_f32));
-  EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64));
-  EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16));
-  EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred));
-  EXPECT_EQ(*tuple, to_from_proto(*tuple));
-  EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple));
+  EXPECT_EQ(one_f32, to_from_proto(one_f32));
+  EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
+  EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
+  EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
+  EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
+  EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
+  EXPECT_EQ(tuple, to_from_proto(tuple));
+  EXPECT_EQ(nested_tuple, to_from_proto(nested_tuple));
   EXPECT_EQ(nil_literal, to_from_proto(nil_literal));
 
-  EXPECT_NE(*one_f32, *two_f32);
-  EXPECT_NE(*one_f32, to_from_proto(*two_f32));
+  EXPECT_NE(one_f32, two_f32);
+  EXPECT_NE(one_f32, to_from_proto(two_f32));
 }
 
 TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
@@ -1802,11 +1788,11 @@
 TEST_F(LiteralUtilTest, SortSparseElements) {
   auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
                                                   SparseIndexArray(10, 3), {});
-  literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
-  literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
-  literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
-  literal->SortSparseElements();
-  EXPECT_EQ(literal->ToString(false),
+  literal.AppendSparseElement<float>({2, 3, 4}, 2.0);
+  literal.AppendSparseElement<float>({3, 4, 5}, 3.0);
+  literal.AppendSparseElement<float>({1, 2, 3}, 1.0);
+  literal.SortSparseElements();
+  EXPECT_EQ(literal.ToString(false),
             "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}");
 }
 
@@ -1816,57 +1802,54 @@
 
   EXPECT_EQ(
       LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
-          ->GetSparseElementAsString(1),
+          .GetSparseElementAsString(1),
       "false");
   EXPECT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
-                ->GetSparseElementAsString(1),
+                .GetSparseElementAsString(1),
             absl::StrCat(int64{2}));
   EXPECT_EQ(
       LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
-          ->GetSparseElementAsString(1),
+          .GetSparseElementAsString(1),
       absl::StrCat(double{2.0}));
   EXPECT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
                                             {half{1.0}, half{2.0}, half{3.0}})
-                ->GetSparseElementAsString(1),
+                .GetSparseElementAsString(1),
             absl::StrCat(static_cast<float>(half{2.0})));
   EXPECT_EQ(LiteralUtil::CreateSparse<complex64>(
                 dimensions, indices,
                 std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
-                ->GetSparseElementAsString(1),
+                .GetSparseElementAsString(1),
             absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
 }
 
 TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+  Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> broadcasted_literal,
-      literal->Broadcast(
-          /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
-          /*dimensions=*/{0}));
-  EXPECT_EQ(*broadcasted_literal,
-            *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
+      Literal broadcasted_literal,
+      literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+                        /*dimensions=*/{0}));
+  EXPECT_EQ(broadcasted_literal,
+            LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
 }
 
 TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
+  Literal literal = LiteralUtil::CreateR1<int64>({1, 2});
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> broadcasted_literal,
-      literal->Broadcast(
-          /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
-          /*dimensions=*/{1}));
-  EXPECT_EQ(*broadcasted_literal,
-            *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
+      Literal broadcasted_literal,
+      literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+                        /*dimensions=*/{1}));
+  EXPECT_EQ(broadcasted_literal,
+            LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
 }
 
 TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
+  Literal literal = LiteralUtil::CreateR0<int32>(9);
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> broadcasted_literal,
-      literal->Broadcast(
-          /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
-          /*dimensions=*/{}));
-  EXPECT_EQ(*broadcasted_literal,
-            *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
+      Literal broadcasted_literal,
+      literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+                        /*dimensions=*/{}));
+  EXPECT_EQ(broadcasted_literal,
+            LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 613449c..0cb1ae3 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -45,7 +45,7 @@
 // Return a literal with all arrays of type FromNativeT converted to type
 // ToNativeT in the given literal.
 template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+Literal ConvertType(LiteralSlice literal) {
   // First construct shape of the result.
   Shape result_shape(literal.shape());
   ShapeUtil::ForEachMutableSubshape(
@@ -56,7 +56,7 @@
               primitive_util::NativeToPrimitiveType<ToNativeT>());
         }
       });
-  auto result = absl::make_unique<Literal>(result_shape);
+  Literal result(result_shape);
 
   // Then copy over the data from 'literal' converting FromNativeT values to
   // ToNativeT values as necessary.
@@ -67,14 +67,14 @@
           if (subshape.element_type() ==
               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
             auto src = literal.data<FromNativeT>(shape_index);
-            auto dest = result->data<ToNativeT>(shape_index);
+            auto dest = result.data<ToNativeT>(shape_index);
             for (int64 i = 0; i < src.size(); ++i) {
               dest[i] = static_cast<ToNativeT>(src[i]);
             }
           } else {
-            TF_CHECK_OK(result->CopyFrom(literal,
-                                         /*dest_shape_index=*/shape_index,
-                                         /*src_shape_index=*/shape_index));
+            TF_CHECK_OK(result.CopyFrom(literal,
+                                        /*dest_shape_index=*/shape_index,
+                                        /*src_shape_index=*/shape_index));
           }
         }
       });
@@ -83,53 +83,52 @@
 
 }  // namespace
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
+/* static */ Literal LiteralUtil::CreateFromDimensions(
     PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
   return Literal::CreateFromShape(
       ShapeUtil::MakeShape(primitive_type, dimensions));
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
+/* static */ Literal LiteralUtil::ConvertBF16ToF32(
     const LiteralSlice& bf16_literal) {
   return ConvertType<bfloat16, float>(bf16_literal);
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
+/* static */ Literal LiteralUtil::ConvertF32ToBF16(
     const LiteralSlice& f32_literal) {
   return ConvertType<float, bfloat16>(f32_literal);
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
-  return absl::make_unique<Literal>(ShapeUtil::MakeTokenShape());
+/* static */ Literal LiteralUtil::CreateToken() {
+  return Literal(ShapeUtil::MakeTokenShape());
 }
 
 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return std::move(*LiteralUtil::CreateR0<uint8>(0));
+      return LiteralUtil::CreateR0<uint8>(0);
     case U32:
-      return std::move(*LiteralUtil::CreateR0<uint32>(0));
+      return LiteralUtil::CreateR0<uint32>(0);
     case U64:
-      return std::move(*LiteralUtil::CreateR0<uint64>(0));
+      return LiteralUtil::CreateR0<uint64>(0);
     case S8:
-      return std::move(*LiteralUtil::CreateR0<int8>(0));
+      return LiteralUtil::CreateR0<int8>(0);
     case S32:
-      return std::move(*LiteralUtil::CreateR0<int32>(0));
+      return LiteralUtil::CreateR0<int32>(0);
     case S64:
-      return std::move(*LiteralUtil::CreateR0<int64>(0));
+      return LiteralUtil::CreateR0<int64>(0);
     case F16:
-      return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
+      return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
     case BF16:
-      return std::move(
-          *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+      return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
     case F32:
-      return std::move(*LiteralUtil::CreateR0<float>(0));
+      return LiteralUtil::CreateR0<float>(0);
     case F64:
-      return std::move(*LiteralUtil::CreateR0<double>(0));
+      return LiteralUtil::CreateR0<double>(0);
     case C64:
-      return std::move(*LiteralUtil::CreateR0<complex64>(0));
+      return LiteralUtil::CreateR0<complex64>(0);
     case PRED:
-      return std::move(*LiteralUtil::CreateR0<bool>(false));
+      return LiteralUtil::CreateR0<bool>(false);
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -145,30 +144,29 @@
 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return std::move(*LiteralUtil::CreateR0<uint8>(1));
+      return LiteralUtil::CreateR0<uint8>(1);
     case U32:
-      return std::move(*LiteralUtil::CreateR0<uint32>(1));
+      return LiteralUtil::CreateR0<uint32>(1);
     case U64:
-      return std::move(*LiteralUtil::CreateR0<uint64>(1));
+      return LiteralUtil::CreateR0<uint64>(1);
     case S8:
-      return std::move(*LiteralUtil::CreateR0<int8>(1));
+      return LiteralUtil::CreateR0<int8>(1);
     case S32:
-      return std::move(*LiteralUtil::CreateR0<int32>(1));
+      return LiteralUtil::CreateR0<int32>(1);
     case S64:
-      return std::move(*LiteralUtil::CreateR0<int64>(1));
+      return LiteralUtil::CreateR0<int64>(1);
     case F16:
-      return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
+      return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
     case BF16:
-      return std::move(
-          *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+      return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
     case F32:
-      return std::move(*LiteralUtil::CreateR0<float>(1));
+      return LiteralUtil::CreateR0<float>(1);
     case F64:
-      return std::move(*LiteralUtil::CreateR0<double>(1));
+      return LiteralUtil::CreateR0<double>(1);
     case C64:
-      return std::move(*LiteralUtil::CreateR0<complex64>(1));
+      return LiteralUtil::CreateR0<complex64>(1);
     case PRED:
-      return std::move(*LiteralUtil::CreateR0<bool>(true));
+      return LiteralUtil::CreateR0<bool>(true);
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -184,42 +182,36 @@
 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return std::move(
-          *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+      return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
     case U32:
-      return std::move(
-          *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+      return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
     case U64:
-      return std::move(
-          *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+      return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
     case S8:
-      return std::move(
-          *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
+      return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
     case S32:
-      return std::move(
-          *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
+      return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
     case S64:
-      return std::move(
-          *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
+      return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
     case F32:
-      return std::move(*LiteralUtil::CreateR0<float>(
-          -std::numeric_limits<float>::infinity()));
+      return LiteralUtil::CreateR0<float>(
+          -std::numeric_limits<float>::infinity());
     case F64:
-      return std::move(*LiteralUtil::CreateR0<double>(
-          -std::numeric_limits<double>::infinity()));
+      return LiteralUtil::CreateR0<double>(
+          -std::numeric_limits<double>::infinity());
     case C64:
       LOG(FATAL) << "C64 element type has no minimum value";
     case PRED:
-      return std::move(*LiteralUtil::CreateR0<bool>(false));
+      return LiteralUtil::CreateR0<bool>(false);
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case F16:
-      return std::move(*LiteralUtil::CreateR0<half>(
-          static_cast<half>(-std::numeric_limits<float>::infinity())));
+      return LiteralUtil::CreateR0<half>(
+          static_cast<half>(-std::numeric_limits<float>::infinity()));
     case BF16:
-      return std::move(*LiteralUtil::CreateR0<bfloat16>(
-          static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
+      return LiteralUtil::CreateR0<bfloat16>(
+          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -232,40 +224,34 @@
 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
   switch (primitive_type) {
     case U8:
-      return std::move(
-          *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+      return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
     case U32:
-      return std::move(
-          *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+      return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
     case U64:
-      return std::move(
-          *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+      return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
     case S8:
-      return std::move(
-          *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
+      return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
     case S32:
-      return std::move(
-          *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
+      return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
     case S64:
-      return std::move(
-          *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
+      return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
     case F32:
-      return std::move(*LiteralUtil::CreateR0<float>(
-          std::numeric_limits<float>::infinity()));
+      return LiteralUtil::CreateR0<float>(
+          std::numeric_limits<float>::infinity());
     case F64:
-      return std::move(*LiteralUtil::CreateR0<double>(
-          std::numeric_limits<double>::infinity()));
+      return LiteralUtil::CreateR0<double>(
+          std::numeric_limits<double>::infinity());
     case PRED:
-      return std::move(*LiteralUtil::CreateR0<bool>(true));
+      return LiteralUtil::CreateR0<bool>(true);
     case S16:
     case U16:
       LOG(FATAL) << "u16/s16 literals not yet implemented";
     case F16:
-      return std::move(*LiteralUtil::CreateR0<half>(
-          static_cast<half>(std::numeric_limits<float>::infinity())));
+      return LiteralUtil::CreateR0<half>(
+          static_cast<half>(std::numeric_limits<float>::infinity()));
     case BF16:
-      return std::move(*LiteralUtil::CreateR0<bfloat16>(
-          static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
+      return LiteralUtil::CreateR0<bfloat16>(
+          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -275,31 +261,29 @@
   }
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+/* static */ Literal LiteralUtil::CreateR1(
     const tensorflow::core::Bitmap& values) {
-  auto literal = absl::make_unique<Literal>(
+  Literal literal(
       ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
-  literal->PopulateR1(values);
+  literal.PopulateR1(values);
   return literal;
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
-    absl::string_view value) {
-  auto literal = absl::make_unique<Literal>(
-      ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
+/* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
+  Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
   for (int i = 0; i < value.size(); ++i) {
-    literal->Set<uint8>({i}, value[i]);
+    literal.Set<uint8>({i}, value[i]);
   }
   return literal;
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
-    float from, float to, int64 rows, int64 cols) {
+/* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
+                                                      int64 rows, int64 cols) {
   auto value = MakeLinspaceArray2D(from, to, rows, cols);
   return CreateR2FromArray2D(*value);
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
+/* static */ Literal LiteralUtil::ReshapeSlice(
     absl::Span<const int64> new_dimensions,
     absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
   int64 new_num_elements = 1;
@@ -309,13 +293,13 @@
   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
 
-  auto new_literal = absl::make_unique<Literal>(
+  Literal new_literal(
       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
 
   // Create a new shape with the given minor-to-major layout. This shape is used
   // solely for converting linear address to multi-dimensional addresses when
   // writing elements to the new literal.
-  Shape shape_with_layout = new_literal->shape();
+  Shape shape_with_layout = new_literal.shape();
   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
 
   // Copy data into new literal, element-by-element.
@@ -326,40 +310,40 @@
         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
     switch (literal.shape().element_type()) {
       case PRED:
-        new_literal->Set<bool>(to_multi_index,
-                               literal.Get<bool>(from_multi_index));
+        new_literal.Set<bool>(to_multi_index,
+                              literal.Get<bool>(from_multi_index));
         break;
       case U8:
-        new_literal->Set<uint8>(to_multi_index,
-                                literal.Get<uint8>(from_multi_index));
+        new_literal.Set<uint8>(to_multi_index,
+                               literal.Get<uint8>(from_multi_index));
         break;
       case U32:
-        new_literal->Set<uint32>(to_multi_index,
-                                 literal.Get<uint32>(from_multi_index));
+        new_literal.Set<uint32>(to_multi_index,
+                                literal.Get<uint32>(from_multi_index));
         break;
       case S32:
-        new_literal->Set<int32>(to_multi_index,
-                                literal.Get<int32>(from_multi_index));
+        new_literal.Set<int32>(to_multi_index,
+                               literal.Get<int32>(from_multi_index));
         break;
       case U64:
-        new_literal->Set<uint64>(to_multi_index,
-                                 literal.Get<uint64>(from_multi_index));
+        new_literal.Set<uint64>(to_multi_index,
+                                literal.Get<uint64>(from_multi_index));
         break;
       case S64:
-        new_literal->Set<int64>(to_multi_index,
-                                literal.Get<int64>(from_multi_index));
+        new_literal.Set<int64>(to_multi_index,
+                               literal.Get<int64>(from_multi_index));
         break;
       case F32:
-        new_literal->Set<float>(to_multi_index,
-                                literal.Get<float>(from_multi_index));
+        new_literal.Set<float>(to_multi_index,
+                               literal.Get<float>(from_multi_index));
         break;
       case F64:
-        new_literal->Set<double>(to_multi_index,
-                                 literal.Get<double>(from_multi_index));
+        new_literal.Set<double>(to_multi_index,
+                                literal.Get<double>(from_multi_index));
         break;
       case C64:
-        new_literal->Set<complex64>(to_multi_index,
-                                    literal.Get<complex64>(from_multi_index));
+        new_literal.Set<complex64>(to_multi_index,
+                                   literal.Get<complex64>(from_multi_index));
         break;
       default:
         LOG(FATAL) << "Unhandled primitive element type: "
@@ -376,97 +360,82 @@
   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
   switch (literal.shape().element_type()) {
     case PRED:
-      return std::move(
-          *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
+      return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
     // 8 bit types.
     case S8:
-      return std::move(
-          *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
+      return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
     case U8:
-      return std::move(
-          *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
+      return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
     // 16 bit types.
     case BF16:
-      return std::move(*LiteralUtil::CreateR0<bfloat16>(
-          literal.GetFirstElement<bfloat16>()));
+      return LiteralUtil::CreateR0<bfloat16>(
+          literal.GetFirstElement<bfloat16>());
     case F16:
-      return std::move(
-          *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
+      return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
     case S16:
-      return std::move(
-          *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
+      return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
     case U16:
-      return std::move(
-          *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
+      return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
     // 32 bit types.
     case F32:
-      return std::move(
-          *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
+      return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
     case S32:
-      return std::move(
-          *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
+      return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
     case U32:
-      return std::move(
-          *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
+      return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
     // 64 bit types.
     case C64:
-      return std::move(*LiteralUtil::CreateR0<complex64>(
-          literal.GetFirstElement<complex64>()));
+      return LiteralUtil::CreateR0<complex64>(
+          literal.GetFirstElement<complex64>());
     case F64:
-      return std::move(
-          *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
+      return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
     case S64:
-      return std::move(
-          *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
+      return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
     case U64:
-      return std::move(
-          *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
+      return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
     default:
       LOG(FATAL) << "Unhandled primitive type "
                  << literal.shape().element_type();
   }
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
+/* static */ Literal LiteralUtil::MakeTuple(
     absl::Span<const Literal* const> elements) {
   std::vector<Shape> element_shapes;
   for (const auto* element : elements) {
     element_shapes.push_back(element->shape());
   }
-  auto literal =
-      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
   for (int i = 0; i < elements.size(); ++i) {
-    TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
+    TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
   }
   return literal;
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
+/* static */ Literal LiteralUtil::MakeTupleFromSlices(
     absl::Span<const LiteralSlice> elements) {
   std::vector<Shape> element_shapes;
   for (const auto& element : elements) {
     element_shapes.push_back(element.shape());
   }
-  auto literal =
-      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
   for (int i = 0; i < elements.size(); ++i) {
-    TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i}));
+    TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
   }
   return literal;
 }
 
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
-    std::vector<std::unique_ptr<Literal>> elements) {
+/* static */ Literal LiteralUtil::MakeTupleOwned(
+    std::vector<Literal> elements) {
   std::vector<Shape> element_shapes;
   element_shapes.reserve(elements.size());
   for (const auto& element : elements) {
-    element_shapes.push_back(element->shape());
+    element_shapes.push_back(element.shape());
   }
-  auto literal =
-      absl::make_unique<Literal>(ShapeUtil::MakeTupleShape(element_shapes));
+  Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
   for (int64 i = 0; i < elements.size(); ++i) {
     TF_CHECK_OK(
-        literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i}));
+        literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
   }
   return literal;
 }
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2d6084a..2b18162 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -69,36 +69,34 @@
   // The variants not ending with WithLayout use the default XLA layout for the
   // literal's linear representation in memory.
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR0(NativeT value);
+  static Literal CreateR0(NativeT value);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR1(absl::Span<const NativeT> values);
-  static std::unique_ptr<Literal> CreateR1(
-      const tensorflow::core::Bitmap& values);
+  static Literal CreateR1(absl::Span<const NativeT> values);
+  static Literal CreateR1(const tensorflow::core::Bitmap& values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR2(
+  static Literal CreateR2(
       std::initializer_list<std::initializer_list<NativeT>> values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR2WithLayout(
+  static Literal CreateR2WithLayout(
       std::initializer_list<std::initializer_list<NativeT>> values,
       const Layout& layout);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR3(
-      std::initializer_list<
-          std::initializer_list<std::initializer_list<NativeT>>>
-          values);
+  static Literal CreateR3(std::initializer_list<
+                          std::initializer_list<std::initializer_list<NativeT>>>
+                              values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR3WithLayout(
+  static Literal CreateR3WithLayout(
       std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>
           values,
       const Layout& layout);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4(
+  static Literal CreateR4(
       std::initializer_list<std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>>
           values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4WithLayout(
+  static Literal CreateR4WithLayout(
       std::initializer_list<std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>>
           values,
@@ -139,9 +137,10 @@
   //     [9, 10, 11]: 4.0
   //
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateSparse(
-      absl::Span<const int64> dimensions, SparseIndexArray indices,
-      absl::Span<const NativeT> values, bool sort = true);
+  static Literal CreateSparse(absl::Span<const int64> dimensions,
+                              SparseIndexArray indices,
+                              absl::Span<const NativeT> values,
+                              bool sort = true);
 
   // Creates a scalar literal value zero of the given primitive type.
   static Literal Zero(PrimitiveType primitive_type);
@@ -155,130 +154,120 @@
   static Literal MaxValue(PrimitiveType primitive_type);
   // Creates a literal of the given shape where each element is `value`.
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateFullWithDescendingLayout(
+  static Literal CreateFullWithDescendingLayout(
       absl::Span<const int64> dimensions, NativeT value);
 
   // Creates a new literal from an Array type. The variants not ending with
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+  static Literal CreateFromArray(const Array<NativeT>& values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
-      const Array<NativeT>& values, const Layout& layout);
+  static Literal CreateFromArrayWithLayout(const Array<NativeT>& values,
+                                           const Layout& layout);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR2FromArray2D(
-      const Array2D<NativeT>& values);
+  static Literal CreateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
-      const Array2D<NativeT>& values, const Layout& layout);
+  static Literal CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+                                               const Layout& layout);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR3FromArray3D(
-      const Array3D<NativeT>& values);
+  static Literal CreateR3FromArray3D(const Array3D<NativeT>& values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
-      const Array3D<NativeT>& values, const Layout& layout);
+  static Literal CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+                                               const Layout& layout);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4FromArray4D(
-      const Array4D<NativeT>& values);
+  static Literal CreateR4FromArray4D(const Array4D<NativeT>& values);
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
-      const Array4D<NativeT>& values, const Layout& layout);
+  static Literal CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+                                               const Layout& layout);
 
   // Creates a new vector of U8s literal value from a string.
-  static std::unique_ptr<Literal> CreateR1U8(absl::string_view value);
+  static Literal CreateR1U8(absl::string_view value);
 
   // Creates a linspace-populated literal with the given number of rows and
   // columns.
-  static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
-                                                      int64 rows, int64 cols);
+  static Literal CreateR2F32Linspace(float from, float to, int64 rows,
+                                     int64 cols);
 
   // Creates a literal that projects the (x, y) dimensions given in values into
   // the z dimension given by "projection".
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR3Projected(
+  static Literal CreateR3Projected(
       std::initializer_list<std::initializer_list<NativeT>> values,
       int64 projection);
 
   // Creates a literal that projects the (x, y) dimensions given in values into
   // the z and p dimensions given.
   template <typename NativeT>
-  static std::unique_ptr<Literal> CreateR4Projected(
+  static Literal CreateR4Projected(
       std::initializer_list<std::initializer_list<NativeT>> values,
       int64 projection_p, int64 projection_z);
 
   // Returns an identity matrix (rank 2) with the given row and column count.
   template <typename NativeT>
-  static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+  static Literal MakeIdentityR2(int64 size);
 
   // Returns a tuple literal composed of given literals. Data is copied from the
   // given elements into the returned literal.
-  static std::unique_ptr<Literal> MakeTuple(
-      absl::Span<const Literal* const> elements);
+  static Literal MakeTuple(absl::Span<const Literal* const> elements);
 
-  static std::unique_ptr<Literal> MakeTupleFromSlices(
-      absl::Span<const LiteralSlice> elements);
+  static Literal MakeTupleFromSlices(absl::Span<const LiteralSlice> elements);
 
   // As above, but intended to be invoked with move semantics; i.e.
   //
-  //  std::vector<std::unique_ptr<Literal>> elements = ...;
+  //  std::vector<Literal> elements = ...;
   //  auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
   //
   // This would have been declared as an overload, but there is ambiguity
   // in invocation between the above signature and this one.
-  static std::unique_ptr<Literal> MakeTupleOwned(
-      std::vector<std::unique_ptr<Literal>> elements);
+  static Literal MakeTupleOwned(std::vector<Literal> elements);
 
-  // This overload lets you pass a braced list of unique_ptr<Literal>s to
+  // This overload lets you pass a braced list of Literals to
   // MakeTupleOwned:
   //
   //   LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
   //
-  // Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
+  // Simply relying on the MakeTupleOwned(std::vector<Literal>)
   // overload doesn't work because std::initializer_list's elements are always
   // const.
   //
-  // The arguments to this function must all be unique_ptr<Literal>.
+  // The arguments to this function must all be Literal.
   template <typename... Ts>
-  static std::unique_ptr<Literal> MakeTupleOwned(
-      std::unique_ptr<Ts>... elements) {
-    std::array<std::unique_ptr<Literal>, sizeof...(Ts)> arr{
-        std::move(elements)...};
-    std::vector<std::unique_ptr<Literal>> v;
+  static Literal MakeTupleOwned(Ts... elements) {
+    std::array<Literal, sizeof...(Ts)> arr{std::move(elements)...};
+    std::vector<Literal> v;
     v.insert(v.begin(), std::make_move_iterator(arr.begin()),
              std::make_move_iterator(arr.end()));
     return MakeTupleOwned(std::move(v));
   }
 
   // Create a constant token literal. Token types have no value.
-  static std::unique_ptr<Literal> CreateToken();
+  static Literal CreateToken();
 
  // Creates a new Literal object with its values having the primitive_type
  // type, and with dimensions defined by the dimensions parameter.
   // The content of the literal values is the default value of the primitive
   // type of literal itself (0 for numeric types, and false for predicates).
-  static std::unique_ptr<Literal> CreateFromDimensions(
-      PrimitiveType primitive_type, absl::Span<const int64> dimensions);
+  static Literal CreateFromDimensions(PrimitiveType primitive_type,
+                                      absl::Span<const int64> dimensions);
 
   // If the given literal's data type is bfloat16, converts it to a float
   // literal; otherwise, returns a copy of it. If the literal is a tuple,
   // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertBF16ToF32(
-      const LiteralSlice& bf16_literal);
+  static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal);
 
   // If the given literal's data type is float, converts it to a bfloat16
   // literal; otherwise, returns a copy of it. If the literal is a tuple,
   // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertF32ToBF16(
-      const LiteralSlice& f32_literal);
+  static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal);
 
   // Creates a literal with a new shape with the given new dimensions using the
   // data in the given input literal. For reshaping purposes the (flat) data
   // buffer of the input literal is assumed to have the given minor_to_major
   // layout order.
-  static std::unique_ptr<Literal> ReshapeSlice(
-      absl::Span<const int64> new_dimensions,
-      absl::Span<const int64> minor_to_major, const LiteralSlice& literal);
+  static Literal ReshapeSlice(absl::Span<const int64> new_dimensions,
+                              absl::Span<const int64> minor_to_major,
+                              const LiteralSlice& literal);
 
   // Creates a literal with the supplied shape, and uses the provided value
   // generator to populate the literal's values.
@@ -286,7 +275,7 @@
   template <
       PrimitiveType type,
       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+  static StatusOr<Literal> CreateRandomLiteral(
       const Shape& shape,
       const std::function<T(absl::Span<const int64>)>& generator);
 
@@ -297,8 +286,8 @@
   template <
       PrimitiveType type, typename E,
       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
-      const Shape& shape, E* engine, T mean, T stddev);
+  static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, E* engine,
+                                               T mean, T stddev);
 
   // Creates a literal with the supplied shape, and initializes the literal
   // values using a normal distribution with given mean and stddev standard
@@ -307,8 +296,8 @@
   template <
       PrimitiveType type,
       typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
-      const Shape& shape, T mean, T stddev);
+  static StatusOr<Literal> CreateRandomLiteral(const Shape& shape, T mean,
+                                               T stddev);
 
   //
   // End of factory methods.
@@ -322,44 +311,43 @@
 std::ostream& operator<<(std::ostream& out, const Literal& literal);
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
-  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShape(
+/* static */ Literal LiteralUtil::CreateR0(NativeT value) {
+  Literal literal(ShapeUtil::MakeShape(
       primitive_util::NativeToPrimitiveType<NativeT>(), {}));
-  literal->Set({}, value);
+  literal.Set({}, value);
   return literal;
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
-    absl::Span<const NativeT> values) {
-  auto literal = absl::make_unique<Literal>(
+/* static */ Literal LiteralUtil::CreateR1(absl::Span<const NativeT> values) {
+  Literal literal(
       ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                            {static_cast<int64>(values.size())}));
-  literal->PopulateR1(values);
+  literal.PopulateR1(values);
   return literal;
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+/* static */ Literal LiteralUtil::CreateR2WithLayout(
     std::initializer_list<std::initializer_list<NativeT>> values,
     const Layout& layout) {
-  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+  Literal literal(ShapeUtil::MakeShapeWithLayout(
       primitive_util::NativeToPrimitiveType<NativeT>(),
       {static_cast<int64>(values.size()),
        static_cast<int64>(values.begin()->size())},
       AsInt64Slice(layout.minor_to_major())));
-  literal->PopulateR2(values);
+  literal.PopulateR2(values);
   return literal;
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+/* static */ Literal LiteralUtil::CreateR2(
     std::initializer_list<std::initializer_list<NativeT>> values) {
   return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+/* static */ Literal LiteralUtil::CreateR3WithLayout(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         values,
     const Layout& layout) {
@@ -384,14 +372,14 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+/* static */ Literal LiteralUtil::CreateR3(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         values) {
   return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+/* static */ Literal LiteralUtil::CreateR4WithLayout(
     std::initializer_list<std::initializer_list<
         std::initializer_list<std::initializer_list<NativeT>>>>
         values,
@@ -422,23 +410,22 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
+/* static */ Literal LiteralUtil::CreateSparse(
     absl::Span<const int64> dimensions, SparseIndexArray indices,
     absl::Span<const NativeT> values, bool sort) {
   int64 num_elements = values.size();
   int64 rank = dimensions.size();
   CHECK_EQ(num_elements, indices.index_count());
   CHECK_EQ(rank, indices.rank());
-  auto literal =
-      absl::make_unique<Literal>(ShapeUtil::MakeShapeWithSparseLayout(
-          primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
-          indices.max_indices()));
-  literal->PopulateSparse(indices, values, sort);
+  Literal literal(ShapeUtil::MakeShapeWithSparseLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions,
+      indices.max_indices()));
+  literal.PopulateSparse(indices, values, sort);
   return literal;
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+/* static */ Literal LiteralUtil::CreateR4(
     std::initializer_list<std::initializer_list<
         std::initializer_list<std::initializer_list<NativeT>>>>
         values) {
@@ -446,50 +433,48 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
+/* static */ Literal LiteralUtil::CreateFromArrayWithLayout(
     const Array<NativeT>& values, const Layout& layout) {
-  auto literal = absl::make_unique<Literal>(ShapeUtil::MakeShapeWithLayout(
+  Literal literal(ShapeUtil::MakeShapeWithLayout(
       primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
       AsInt64Slice(layout.minor_to_major())));
-  literal->PopulateFromArray(values);
+  literal.PopulateFromArray(values);
   return literal;
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
+/* static */ Literal LiteralUtil::CreateFromArray(
     const Array<NativeT>& values) {
   return CreateFromArrayWithLayout(
       values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
-                                           const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR2FromArray2DWithLayout(
+    const Array2D<NativeT>& values, const Layout& layout) {
   return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+/* static */ Literal LiteralUtil::CreateR2FromArray2D(
     const Array2D<NativeT>& values) {
   return CreateFromArray(values);
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
-                                           const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR3FromArray3DWithLayout(
+    const Array3D<NativeT>& values, const Layout& layout) {
   return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+/* static */ Literal LiteralUtil::CreateR3FromArray3D(
     const Array3D<NativeT>& values) {
   return CreateFromArray(values);
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+/* static */ Literal LiteralUtil::CreateR3Projected(
     std::initializer_list<std::initializer_list<NativeT>> values,
     int64 projection) {
   int64 dim0_size = projection;
@@ -514,7 +499,7 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+/* static */ Literal LiteralUtil::CreateR4Projected(
     std::initializer_list<std::initializer_list<NativeT>> values,
     int64 projection_p, int64 projection_z) {
   int64 dim0_size = projection_p;
@@ -542,21 +527,20 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+/* static */ Literal LiteralUtil::CreateR4FromArray4D(
     const Array4D<NativeT>& values) {
   return CreateFromArray(values);
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
-                                           const Layout& layout) {
+/* static */ Literal LiteralUtil::CreateR4FromArray4DWithLayout(
+    const Array4D<NativeT>& values, const Layout& layout) {
   return CreateFromArrayWithLayout(values, layout);
 }
 
 // Returns an identity matrix (rank 2) with the given row and column count.
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+/* static */ Literal LiteralUtil::MakeIdentityR2(int64 size) {
   Array2D<NativeT> array(size, size, 0);
   for (int64 i = 0; i < size; ++i) {
     array(i, i) = 1;
@@ -565,33 +549,29 @@
 }
 
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal>
-LiteralUtil::CreateFullWithDescendingLayout(absl::Span<const int64> dimensions,
-                                            NativeT value) {
-  auto literal =
-      absl::make_unique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
-          primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
-  literal->PopulateWithValue(value);
+/* static */ Literal LiteralUtil::CreateFullWithDescendingLayout(
+    absl::Span<const int64> dimensions, NativeT value) {
+  Literal literal(ShapeUtil::MakeShapeWithDescendingLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
+  literal.PopulateWithValue(value);
   return literal;
 }
 
 template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
     const Shape& shape,
     const std::function<T(absl::Span<const int64>)>& generator) {
   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
   TF_RET_CHECK(shape.element_type() == type);
-  auto literal = absl::make_unique<Literal>(shape);
-  TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+  Literal literal(shape);
+  TF_RETURN_IF_ERROR(literal.Populate<NativeT>(
       [&](absl::Span<const int64> indexes) { return generator(indexes); }));
   return std::move(literal);
 }
 
 template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
-                                 T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+    const Shape& shape, E* engine, T mean, T stddev) {
   using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
   std::normal_distribution<NativeT> generator(mean, stddev);
   return CreateRandomLiteral<type, NativeT>(
@@ -600,8 +580,8 @@
 }
 
 template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<Literal> LiteralUtil::CreateRandomLiteral(
+    const Shape& shape, T mean, T stddev) {
   std::minstd_rand0 engine;
   return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
 }
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index f9473d3..0f86f9f 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -39,8 +39,8 @@
 
 PackedLiteralReader::~PackedLiteralReader() { delete file_; }
 
-StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
-    const Shape& shape, const Layout* layout) {
+StatusOr<Literal> PackedLiteralReader::Read(const Shape& shape,
+                                            const Layout* layout) {
   VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
           << " layout: "
           << (layout == nullptr ? "<none>" : layout->ShortDebugString());
@@ -57,11 +57,11 @@
         PrimitiveType_Name(shape.element_type()));
   }
 
-  auto result = absl::make_unique<Literal>(literal_shape);
-  result->PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
+  Literal result(literal_shape);
+  result.PopulateWithValue(std::numeric_limits<float>::quiet_NaN());
 
   int64 elements = ShapeUtil::ElementsIn(shape);
-  absl::Span<const float> field = result->data<float>();
+  absl::Span<const float> field = result.data<float>();
   char* data = absl::bit_cast<char*>(field.data());
   uint64 bytes = elements * sizeof(float);
   absl::string_view sp;
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 98dccaa..d6d2ff1 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -41,8 +41,7 @@
   //
   // Layout is optional. If it is not provided, no layout is set on the literal
   // that is produced.
-  StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
-                                          const Layout* layout = nullptr);
+  StatusOr<Literal> Read(const Shape& shape, const Layout* layout = nullptr);
 
   // Returns whether the input file has been fully exhausted; i.e. all available
   // packed literals have been read and we're at the end of the file.
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd6e20b..9da5dc0 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -81,8 +81,8 @@
   return client->TransferToInfeedLocal(literal, device_ordinal);
 }
 
-StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
-    const Shape& shape, int replica_number) {
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+                                                  int replica_number) {
   VLOG(1) << "Outfeeding literal from replica number: " << replica_number
           << " shape: " << shape;
   LocalClient* client = GetOrCreateLocalClient();
@@ -141,9 +141,8 @@
   LocalClient* client = GetOrCreateLocalClient();
   StatusOr<ScopedShapedBuffer> buf = [&] {
     if (shape_with_layout) {
-      std::unique_ptr<Literal> relaid =
-          argument.Relayout(shape_with_layout.value());
-      return ToBuffer(client, /*device_ordinal=*/0, *relaid);
+      Literal relaid = argument.Relayout(shape_with_layout.value());
+      return ToBuffer(client, /*device_ordinal=*/0, relaid);
     }
     return ToBuffer(client, /*device_ordinal=*/0, argument);
   }();
@@ -151,7 +150,7 @@
   return new LocalShapedBuffer(std::move(buf).ValueOrDie());
 }
 
-StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
+StatusOr<Literal> LocalShapedBuffer::ToLiteral() const {
   LocalClient* client = GetOrCreateLocalClient();
   return client->ShapedBufferToLiteral(*shaped_buffer());
 }
@@ -160,7 +159,7 @@
     std::unique_ptr<LocalExecutable> executable)
     : executable_(std::move(executable)) {}
 
-StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
+StatusOr<Literal> CompiledLocalComputation::Execute(
     const std::vector<Literal>& arguments,
     const std::vector<absl::optional<Shape>>& shapes_with_layout) {
   LocalClient* client = GetOrCreateLocalClient();
@@ -169,7 +168,7 @@
 
   // Each replica populates a StatusOr result, but only replica zero actually
   // retrieves its literal value.
-  std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+  std::vector<StatusOr<Literal>> results(GetReplicaCount());
   {
     tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
                                         GetReplicaCount());
@@ -198,9 +197,8 @@
 
               StatusOr<ScopedShapedBuffer> pushed;
               if (shape_with_layout) {
-                std::unique_ptr<Literal> relaid =
-                    argument.Relayout(shape_with_layout.value());
-                pushed = ToBuffer(client, device_ordinal, *relaid);
+                Literal relaid = argument.Relayout(shape_with_layout.value());
+                pushed = ToBuffer(client, device_ordinal, relaid);
               } else {
                 pushed = ToBuffer(client, device_ordinal, argument);
               }
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 78b3c59..1d5dfe5 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -51,8 +51,8 @@
 // Transfers a literal of the given shape from the outfeed of the given replica.
 //
 // The replica number is resolved to an appropriate device ordinal.
-StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
-    const Shape& shape, int replica_number);
+StatusOr<Literal> TransferFromOutfeedLocalReplica(const Shape& shape,
+                                                  int replica_number);
 
 // Wraps a ScopedShapedBuffer produced by copying a literal "to
 // device," i.e. copying a literal to a scoped buffer via the local
@@ -65,7 +65,7 @@
   LocalShapedBuffer(ScopedShapedBuffer shaped_buffer);
   const ScopedShapedBuffer* shaped_buffer() const;
 
-  StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+  StatusOr<Literal> ToLiteral() const;
 
   // Transfers ownership of the encapsulated ShapedBuffer to the caller,
   // analogous to std::unique_ptr::release().
@@ -117,7 +117,7 @@
   // with optionally-specified argument layouts. The literals will be
   // re-laid out according to the corresponding elements of
   // shapes_with_layout.
-  StatusOr<std::unique_ptr<Literal> > Execute(
+  StatusOr<Literal> Execute(
       const std::vector<Literal>& arguments,
       const std::vector<absl::optional<Shape> >& shapes_with_layout);
 
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 76c0951..521490e 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,12 +109,12 @@
 // Must be included first
 #include "tensorflow/python/lib/core/numpy.h"
 
-#include "third_party/absl/strings/str_cat.h"
-#include "third_party/absl/strings/str_format.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "third_party/absl/types/span.h"
+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/python/numpy_bridge.h"
 #include "tensorflow/compiler/xla/python/local_computation_builder.h"
 
@@ -216,9 +216,9 @@
 }
 
 
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
   if ($1.ok()) {
-    std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
+    Literal value = $1.ConsumeValueOrDie();
     $result = numpy::PyObjectFromXlaLiteral(*value);
   } else {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
@@ -346,25 +346,25 @@
 
 // Literal
 
-%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
+%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
   literal_status = numpy::XlaLiteralFromPyObject($input);
   if (!literal_status.ok()) {
     PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
     SWIG_fail;
   }
-  $1 = literal_status.ValueOrDie().get();
+  $1 = &literal_status.ValueOrDie();
 }
 
-%typemap(out) std::unique_ptr<Literal> {
+%typemap(out) Literal {
   $result = numpy::PyObjectFromXlaLiteral(*$1);
 }
 
-%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+%typemap(out) StatusOr<Literal> {
   if (!$1.ok()) {
     PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
     SWIG_fail;
   }
-  $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+  $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
 }
 
 %typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
@@ -375,13 +375,13 @@
   const int size = PySequence_Size($input);
   for (int i = 0; i < size; ++i) {
     PyObject* o = PySequence_GetItem($input, i);
-    StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
+    StatusOr<Literal> literal_status = numpy::XlaLiteralFromPyObject(o);
     if (!literal_status.ok()) {
       PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
       Py_DECREF(o);
       SWIG_fail;
     }
-    temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
+    temps.push_back(literal_status.ConsumeValueOrDie());
     Py_DECREF(o);
   }
   $1 = &temps;
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index fc6511b..b0aa024 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -368,10 +368,10 @@
   }
 }
 
-StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o) {
   if (PyTuple_Check(o)) {
     int num_elements = PyTuple_Size(o);
-    std::vector<std::unique_ptr<Literal>> elements;
+    std::vector<Literal> elements;
     elements.reserve(num_elements);
     for (int i = 0; i < num_elements; i++) {
       PyObject* element = PyTuple_GetItem(o, i);
@@ -389,8 +389,7 @@
     int np_type = PyArray_TYPE(py_array);
     auto literal = LiteralUtil::CreateFromDimensions(
         NumpyTypeToPrimitiveType(np_type), dimensions);
-    TF_RETURN_IF_ERROR(
-        CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
+    TF_RETURN_IF_ERROR(CopyNumpyArrayToLiteral(np_type, py_array, &literal));
     return std::move(literal);
   } else {
     return InvalidArgument(
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 8cae175..40ff2d9 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -82,7 +82,7 @@
 // To avoid transferring ownership of the data buffers that underlie
 // PyArrays and XLA literals, this function makes deep copies of all
 // array data.
-StatusOr<std::unique_ptr<Literal> > XlaLiteralFromPyObject(PyObject* o);
+StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
 
 // The following functions copy array data from the buffers underlying Numpy
 // ndarrays into those underlying XLA literals, and vice versa.
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index a4854f5..ceb5e74 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -186,11 +186,10 @@
 
 /* static  */ std::unique_ptr<std::vector<float>>
 ReferenceUtil::ReduceWindow1DGeneric(
-    const absl::Span<const float>& operand, float init,
+    absl::Span<const float> operand, float init,
     const std::function<float(float, float)>& reduce_func,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride,
-    const absl::Span<const std::pair<int64, int64>>& padding) {
+    absl::Span<const int64> window, absl::Span<const int64> stride,
+    absl::Span<const std::pair<int64, int64>> padding) {
   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
   std::vector<int64> window_counts(window.size(), 0);
   std::vector<int64> pad_low(window.size(), 0);
@@ -218,10 +217,9 @@
 }
 
 /* static  */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
-                                 float init,
-                                 const absl::Span<const int64>& window,
-                                 const absl::Span<const int64>& stride,
+ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
+                                 absl::Span<const int64> window,
+                                 absl::Span<const int64> stride,
                                  Padding padding) {
   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
@@ -234,9 +232,8 @@
 ReferenceUtil::ReduceWindow2DGeneric(
     const Array2D<float>& operand, float init,
     const std::function<float(float, float)>& reduce_func,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride,
-    const absl::Span<const std::pair<int64, int64>>& padding) {
+    absl::Span<const int64> window, absl::Span<const int64> stride,
+    absl::Span<const std::pair<int64, int64>> padding) {
   std::vector<int64> dim_lengths{operand.height(), operand.width()};
 
   std::vector<int64> window_counts(window.size(), 0);
@@ -273,9 +270,8 @@
 }
 
 /* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
-    const Array2D<float>& operand, float init,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride, Padding padding) {
+    const Array2D<float>& operand, float init, absl::Span<const int64> window,
+    absl::Span<const int64> stride, Padding padding) {
   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
   std::vector<int64> dim_lengths{operand.height(), operand.width()};
   return ReduceWindow2DGeneric(
@@ -284,9 +280,8 @@
 }
 
 /* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
-    const Array3D<float>& operand, float init,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride, Padding padding) {
+    const Array3D<float>& operand, float init, absl::Span<const int64> window,
+    absl::Span<const int64> stride, Padding padding) {
   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
 
@@ -332,8 +327,8 @@
 ReferenceUtil::ReduceWindow4DGeneric(
     const Array4D<float>& operand, float init,
     const std::function<float(float, float)>& reduce_func,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride, Padding padding) {
+    absl::Span<const int64> window, absl::Span<const int64> stride,
+    Padding padding) {
   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                  operand.n4()};
   return ReduceWindow4DGeneric(
@@ -345,9 +340,8 @@
 ReferenceUtil::ReduceWindow4DGeneric(
     const Array4D<float>& operand, float init,
     const std::function<float(float, float)>& reduce_func,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride,
-    const absl::Span<const std::pair<int64, int64>>& padding) {
+    absl::Span<const int64> window, absl::Span<const int64> stride,
+    absl::Span<const std::pair<int64, int64>> padding) {
   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                  operand.n4()};
 
@@ -399,9 +393,8 @@
 }
 
 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
-    const Array4D<float>& operand, float init,
-    const absl::Span<const int64>& window,
-    const absl::Span<const int64>& stride, Padding padding) {
+    const Array4D<float>& operand, float init, absl::Span<const int64> window,
+    absl::Span<const int64> stride, Padding padding) {
   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
                                padding);
@@ -425,8 +418,8 @@
 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
                                         const Array4D<float>& source,
                                         float init,
-                                        const absl::Span<const int64>& window,
-                                        const absl::Span<const int64>& stride,
+                                        absl::Span<const int64> window,
+                                        absl::Span<const int64> stride,
                                         bool same_padding) {
   Padding padding = same_padding ? Padding::kSame : Padding::kValid;
   auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
@@ -529,13 +522,13 @@
   }
 
   ordered_input_dimensions[0] =
-      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
+      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
   ordered_input_dimensions[1] =
-      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
+      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
   ordered_kernel_dimensions[0] =
-      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
+      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
   ordered_kernel_dimensions[1] =
-      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));
+      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
 
   std::vector<std::pair<int64, int64>> paddings =
       MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
@@ -546,7 +539,7 @@
 
   WindowDimension dim;
   dim.set_size(
-      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
+      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
   dim.set_stride(kernel_stride.first);
   dim.set_padding_low(paddings[0].first);
   dim.set_padding_high(paddings[0].second);
@@ -556,7 +549,7 @@
 
   WindowDimension dim2;
   dim2.set_size(
-      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
+      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
   dim2.set_stride(kernel_stride.second);
   dim2.set_padding_low(paddings[1].first);
   dim2.set_padding_high(paddings[1].second);
@@ -564,35 +557,39 @@
   dim2.set_base_dilation(lhs_dilation.second);
   *window.add_dimensions() = dim2;
 
-  const Shape& shape =
-      ShapeInference::InferConvolveShape(lhs_literal->shape(),
-                                         rhs_literal->shape(), window, dnums)
-          .ConsumeValueOrDie();
+  const Shape& shape = ShapeInference::InferConvolveShape(
+                           lhs_literal.shape(), rhs_literal.shape(),
+                           /*feature_group_count=*/1, window, dnums)
+                           .ConsumeValueOrDie();
 
   HloInstruction* lhs_instruction =
       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
   HloInstruction* rhs_instruction =
       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
 
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      /*new_size=*/2, PrecisionConfig::DEFAULT);
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, precision_config));
   HloModuleConfig config;
   HloModule module("ReferenceUtil", config);
   auto computation = module.AddEntryComputation(b.Build());
 
   HloEvaluator evaluator;
-  std::unique_ptr<Literal> result_literal =
+  Literal result_literal =
       evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();
 
-  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
+  CHECK_EQ(ShapeUtil::Rank(result_literal.shape()), 4);
   auto result =
-      absl::make_unique<Array4D<float>>(result_literal->shape().dimensions(0),
-                                        result_literal->shape().dimensions(1),
-                                        result_literal->shape().dimensions(2),
-                                        result_literal->shape().dimensions(3));
+      absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
+                                        result_literal.shape().dimensions(1),
+                                        result_literal.shape().dimensions(2),
+                                        result_literal.shape().dimensions(3));
 
   result->Each([&](absl::Span<const int64> indices, float* value) {
-    *value = result_literal->Get<float>(indices);
+    *value = result_literal.Get<float>(indices);
   });
 
   return result;
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9ce0980..8654fbb 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -177,47 +177,41 @@
 
   // Windowed reductions with Add as the function to apply.
   static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
-      const absl::Span<const float>& operand, float init,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, Padding padding);
+      absl::Span<const float> operand, float init,
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      Padding padding);
   static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
-      const Array2D<float>& operand, float init,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, Padding padding);
+      const Array2D<float>& operand, float init, absl::Span<const int64> window,
+      absl::Span<const int64> stride, Padding padding);
   static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
-      const Array3D<float>& operand, float init,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, Padding padding);
+      const Array3D<float>& operand, float init, absl::Span<const int64> window,
+      absl::Span<const int64> stride, Padding padding);
   static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
-      const Array4D<float>& operand, float init,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, Padding padding);
+      const Array4D<float>& operand, float init, absl::Span<const int64> window,
+      absl::Span<const int64> stride, Padding padding);
 
   // Windowed reductions with a generic reduce function.
   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
-      const absl::Span<const float>& operand, float init,
+      absl::Span<const float> operand, float init,
       const std::function<float(float, float)>& reduce_func,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride,
-      const absl::Span<const std::pair<int64, int64>>& padding);
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      absl::Span<const std::pair<int64, int64>> padding);
   static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
       const Array2D<float>& operand, float init,
       const std::function<float(float, float)>& reduce_func,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride,
-      const absl::Span<const std::pair<int64, int64>>& padding);
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      absl::Span<const std::pair<int64, int64>> padding);
   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
       const Array4D<float>& operand, float init,
       const std::function<float(float, float)>& reduce_func,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, Padding padding);
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      Padding padding);
   // With arbitrary padding.
   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
       const Array4D<float>& operand, float init,
       const std::function<float(float, float)>& reduce_func,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride,
-      const absl::Span<const std::pair<int64, int64>>& padding);
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      absl::Span<const std::pair<int64, int64>> padding);
 
   // Batch normalize data.
   static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -230,8 +224,8 @@
   // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
   static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
       const Array4D<float>& operand, const Array4D<float>& source, float init,
-      const absl::Span<const int64>& window,
-      const absl::Span<const int64>& stride, bool same_padding);
+      absl::Span<const int64> window, absl::Span<const int64> stride,
+      bool same_padding);
 
   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -332,8 +326,8 @@
 
   // Slices with index clamping
   template <typename T>
-  static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
-                                     int64 start, int64 size) {
+  static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
+                                     int64 size) {
     start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
     std::vector<T> result;
     for (int64 i = 0; i < size; ++i) {
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 3ec0192..a1b0f40 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -55,7 +55,7 @@
   auto result = ReferenceUtil::TransposeArray2D(*matrix_);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
   LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
-                                       *actual_literal, ErrorSpec(0.0001));
+                                       actual_literal, ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, MatmulArray2D) {
@@ -67,14 +67,14 @@
   auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
   LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
-                                       *actual_literal, ErrorSpec(0.0001));
+                                       actual_literal, ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
   auto add = [](float lhs, float rhs) { return lhs + rhs; };
   auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
   auto actual_literal = LiteralUtil::CreateR1<float>(*result);
-  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
+  LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, actual_literal,
                                        ErrorSpec(0.0001));
 }
 
@@ -82,7 +82,7 @@
   auto add = [](float lhs, float rhs) { return lhs + rhs; };
   auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
   auto actual_literal = LiteralUtil::CreateR1<float>(*result);
-  LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
+  LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, actual_literal,
                                        ErrorSpec(0.0001));
 }
 
@@ -90,14 +90,14 @@
   auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
       Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
       [](float a, float b) { return a + b; }));
-  LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
+  LiteralTestUtil::ExpectR1Equal<float>({0}, result);
 }
 
 TEST_F(ReferenceUtilTest, MapArray2D) {
   auto identity = [](float value) { return log(exp(value)); };
   auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
-  LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
+  LiteralTestUtil::ExpectR2NearArray2D(*matrix_, actual_literal,
                                        ErrorSpec(0.0001));
 }
 
@@ -108,7 +108,7 @@
   auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
   LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
-                                       *actual_literal, ErrorSpec(0.0001));
+                                       actual_literal, ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, MapArray4D) {
@@ -121,7 +121,7 @@
 
   Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
   expected.FillWithMultiples(2.0f);
-  LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
                                        ErrorSpec(0.0001));
 }
 
@@ -138,7 +138,7 @@
 
   Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
   expected.Fill(0.0f);
-  LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D(expected, actual_literal,
                                        ErrorSpec(0.0001));
 }
 
@@ -146,16 +146,16 @@
   auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 
-  LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
-                                       *actual_literal, ErrorSpec(0.0001));
+  LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, actual_literal,
+                                       ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
   auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
 
-  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
-                                       *actual_literal, ErrorSpec(0.0001));
+  LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, actual_literal,
+                                       ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, SliceArray3D) {
@@ -167,7 +167,7 @@
   auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
 
   LiteralTestUtil::ExpectR3Near<float>(
-      {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+      {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, actual_literal,
       ErrorSpec(0.0001));
 }
 
@@ -180,8 +180,8 @@
   auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
 
   LiteralTestUtil::ExpectR3Near<float>(
-      {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
-      *actual_literal, ErrorSpec(0.0001));
+      {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, actual_literal,
+      ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, SliceArray4D) {
@@ -194,7 +194,7 @@
 
   LiteralTestUtil::ExpectR4Near<float>(
       {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
-      *actual_literal, ErrorSpec(0.0001));
+      actual_literal, ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
@@ -208,7 +208,7 @@
   LiteralTestUtil::ExpectR4Near<float>(
       {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
         {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
-      *actual_literal, ErrorSpec(0.0001));
+      actual_literal, ErrorSpec(0.0001));
 }
 
 TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
@@ -220,7 +220,7 @@
 
   auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
 
-  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -233,7 +233,7 @@
 
   auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
 
-  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR3NearArray3D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -268,7 +268,7 @@
 
   auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 
-  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -302,7 +302,7 @@
 
   auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 
-  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -358,7 +358,7 @@
 
   auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 
-  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -411,7 +411,7 @@
 
   auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
 
-  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+  LiteralTestUtil::ExpectR4NearArray4D<float>(expected, actual_literal,
                                               ErrorSpec(0.0001));
 }
 
@@ -424,7 +424,7 @@
       [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
   auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
   LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
-                                *actual_literal, ErrorSpec(0.0001));
+                                actual_literal, ErrorSpec(0.0001));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 43fd8fe..84fe5b1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -95,12 +95,11 @@
   std::vector<float> expected = {
       1.85840735, -1.85840735, 2.28318531,   -2.28318531,  -6.42477796,
       6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
-  std::unique_ptr<Literal> expected_literal =
-      LiteralUtil::CreateR1<float>(expected);
+  Literal expected_literal = LiteralUtil::CreateR1<float>(expected);
   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
   TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
                                                    computation, {}, nullptr));
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+  EXPECT_TRUE(LiteralTestUtil::Near(expected_literal, result_literal,
                                     ErrorSpec(0.0001)));
 }
 
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 26b48cf..fb80c78 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -87,6 +87,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
@@ -123,6 +124,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
@@ -159,6 +161,7 @@
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
     ],
@@ -291,6 +294,7 @@
         "hlo_instructions.cc",
         "hlo_module.cc",
         "hlo_opcode.cc",
+        "hlo_schedule.cc",
         "hlo_sharding.cc",
     ],
     hdrs = [
@@ -303,6 +307,7 @@
         "hlo_instructions.h",
         "hlo_module.h",
         "hlo_opcode.h",
+        "hlo_schedule.h",
         "hlo_sharding.h",
     ],
     deps = [
@@ -331,6 +336,8 @@
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -347,6 +354,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -397,6 +405,7 @@
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -493,6 +502,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -541,6 +551,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -563,6 +574,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -989,6 +1001,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -1006,8 +1019,8 @@
         ":buffer_value_containers",
         ":heap_simulator",
         ":hlo",
+        ":hlo_memory_scheduler",
         ":hlo_proto",
-        ":hlo_scheduling",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -1035,8 +1048,8 @@
         ":cpu_plugin",
         ":flatten_call_graph",
         ":hlo",
+        ":hlo_memory_scheduler",
         ":hlo_ordering",
-        ":hlo_scheduling",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
@@ -1049,6 +1062,7 @@
         "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -1081,14 +1095,15 @@
     deps = [
         ":hlo",
         ":hlo_dataflow_analysis",
+        ":hlo_memory_scheduler",
         ":hlo_ordering",
-        ":hlo_scheduling",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
     ],
 )
 
@@ -1123,13 +1138,46 @@
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
         "@com_google_absl//absl/memory",
     ],
 )
 
 cc_library(
+    name = "hlo_module_group",
+    srcs = ["hlo_module_group.cc"],
+    hdrs = ["hlo_module_group.h"],
+    deps = [
+        ":hlo",
+        ":hlo_proto",
+        "//tensorflow/compiler/xla:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "hlo_module_group_test",
+    srcs = ["hlo_module_group_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_matchers",
+        ":hlo_module_group",
+        ":hlo_parser",
+        ":hlo_proto",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+    ],
+)
+
+cc_library(
     name = "hlo_module_group_metadata",
     srcs = ["hlo_module_group_metadata.cc"],
     hdrs = ["hlo_module_group_metadata.h"],
@@ -1169,14 +1217,35 @@
     ],
 )
 
+tf_cc_test(
+    name = "hlo_schedule_test",
+    srcs = ["hlo_schedule_test.cc"],
+    deps = [
+        ":heap_simulator",
+        ":hlo",
+        ":hlo_dce",
+        ":hlo_memory_scheduler",
+        ":hlo_ordering",
+        ":hlo_parser",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/algorithm:container",
+    ],
+)
+
 cc_library(
-    name = "hlo_scheduling",
-    srcs = ["hlo_scheduling.cc"],
-    hdrs = ["hlo_scheduling.h"],
+    name = "hlo_memory_scheduler",
+    srcs = ["hlo_memory_scheduler.cc"],
+    hdrs = ["hlo_memory_scheduler.h"],
     deps = [
         ":heap_simulator",
         ":hlo",
         ":hlo_ordering",
+        ":hlo_pass",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -1190,21 +1259,22 @@
 )
 
 tf_cc_test(
-    name = "hlo_scheduling_test",
-    srcs = ["hlo_scheduling_test.cc"],
+    name = "hlo_memory_scheduler_test",
+    srcs = ["hlo_memory_scheduler_test.cc"],
     deps = [
         ":heap_simulator",
         ":hlo",
         ":hlo_dce",
+        ":hlo_memory_scheduler",
         ":hlo_ordering",
         ":hlo_parser",
-        ":hlo_scheduling",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/algorithm:container",
     ],
 )
 
@@ -1229,6 +1299,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/memory",
     ],
 )
 
@@ -1362,6 +1433,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
@@ -1678,6 +1750,7 @@
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/core:test",
     ],
 )
@@ -1747,6 +1820,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/memory",
@@ -1922,6 +1996,9 @@
     srcs = ["hlo_module_test.cc"],
     deps = [
         ":hlo",
+        ":hlo_matchers",
+        ":hlo_memory_scheduler",
+        ":hlo_parser",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
@@ -1930,6 +2007,7 @@
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "//tensorflow/core:test",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
@@ -2203,6 +2281,7 @@
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -2281,6 +2360,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/core:test",
     ],
 )
@@ -2361,12 +2441,11 @@
         ":buffer_liveness",
         ":buffer_value",
         ":call_graph",
-        ":copy_insertion",
         ":flatten_call_graph",
         ":hlo",
         ":hlo_dce",
+        ":hlo_memory_scheduler",
         ":hlo_ordering",
-        ":hlo_scheduling",
         ":logical_buffer",
         ":tuple_points_to_analysis",
         "//tensorflow/compiler/xla:shape_util",
@@ -2395,6 +2474,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -2461,6 +2541,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
@@ -2520,6 +2601,7 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:inlined_vector",
     ],
 )
 
@@ -2538,6 +2620,7 @@
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
@@ -2576,6 +2659,7 @@
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
@@ -2853,6 +2937,7 @@
     deps = [
         ":hlo_tfgraph_builder",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
     ],
@@ -3187,6 +3272,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:inlined_vector",
     ],
 )
 
@@ -3289,6 +3375,8 @@
     size = "small",
     srcs = ["hlo_parser_test.cc"],
     deps = [
+        ":hlo",
+        ":hlo_casting_utils",
         ":hlo_matchers",
         ":hlo_parser",
         "//tensorflow/compiler/xla:window_util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 7c078f0..5458159 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -205,7 +205,7 @@
   HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
     HloInstruction* zero =
         computation_->AddInstruction(HloInstruction::CreateConstant(
-            LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
+            LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
     HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
     Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
     return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -296,6 +296,14 @@
     return scalar_add_computation_;
   }
 
+  // Tries to fold a kPad in the input or filter into the convolution
+  // instruction's window.
+  StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
+  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);
+
+  // Tries to use a kDot in place of the given convolution.
+  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);
+
   // Current HloComputation instance the AlgebraicSimplifierVisitor is
   // traversing.
   HloComputation* computation_;
@@ -312,7 +320,8 @@
   // Disable dot strength reduction on platforms where it causes a slowdown.
   bool enable_dot_strength_reduction_;
 
-  // Disable convolution simplification on platforms where it causes a slowdown.
+  // Disable convolution -> dot simplification on platforms where it causes a
+  // slowdown.
   bool enable_conv_simplification_;
 
   // Cached computation for adding two scalar F32.
@@ -527,7 +536,7 @@
     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
   } else {
     return computation->AddInstruction(
-        HloInstruction::CreateConstant(literal.CloneToUnique()));
+        HloInstruction::CreateConstant(literal.Clone()));
   }
 }
 
@@ -546,7 +555,7 @@
   // If a literal is all the same element replace it with a scalar broadcast.
   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
       constant->literal().IsAllFirst()) {
-    std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>(
+    Literal unique_scalar(
         LiteralUtil::GetFirstScalarLiteral(constant->literal()));
     HloInstruction* scalar = computation_->AddInstruction(
         HloInstruction::CreateConstant(std::move(unique_scalar)));
@@ -676,7 +685,7 @@
         return Status::OK();
     }
     auto inverse = computation_->AddInstruction(
-        HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+        HloInstruction::CreateConstant((new_literal.Clone())));
     TF_ASSIGN_OR_RETURN(auto new_divide,
                         MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
     return ReplaceInstruction(divide, new_divide);
@@ -950,9 +959,9 @@
       new_dot_rhs = rhs_slice;
     }
 
-    auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
-        dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums));
-    new_dot->set_precision_config(dot.precision_config());
+    auto* new_dot = computation_->AddInstruction(
+        HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs,
+                                  new_dot_dnums, dot.precision_config()));
 
     if (add_result) {
       add_result = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -1053,9 +1062,9 @@
   const int n =
       right_operand->shape().dimensions(1 - rhs_contracting_dimension);
   auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
-  auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot(
-      memoized_shape, left_operand, right_operand, dnums));
-  memoized_inst->set_precision_config(dot->precision_config());
+  auto* memoized_inst = computation_->AddInstruction(
+      HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
+                                dnums, dot->precision_config()));
   // Get pair {start, 0} or {0, start}.
   HloInstruction* original_start_indices =
       lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1);
@@ -1151,9 +1160,8 @@
     dot_dimension_numbers.add_rhs_contracting_dimensions(0);
     auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
         ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
-        rhs->mutable_operand(0), lhs->mutable_operand(0),
-        dot_dimension_numbers));
-    new_dot->set_precision_config(dot->precision_config());
+        rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers,
+        dot->precision_config()));
     return ReplaceWithNewInstruction(
         dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
   }
@@ -1470,7 +1478,7 @@
   auto* iota = Cast<HloIotaInstruction>(instruction);
   if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
     auto zero = computation_->AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::Zero(iota->shape().element_type()).CloneToUnique()));
+        LiteralUtil::Zero(iota->shape().element_type()).Clone()));
     return ReplaceWithNewInstruction(
         iota, HloInstruction::CreateBroadcast(iota->shape(), zero, {}));
   }
@@ -1573,7 +1581,7 @@
   CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
   if (IsAll(rhs, 0)) {
     auto one = HloInstruction::CreateConstant(
-        LiteralUtil::One(power->shape().element_type()).CloneToUnique());
+        LiteralUtil::One(power->shape().element_type()).Clone());
     std::unique_ptr<HloInstruction> ones;
     if (ShapeUtil::IsScalar(power->shape())) {
       ones = std::move(one);
@@ -1608,7 +1616,7 @@
   VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
   if (IsAll(rhs, -1)) {
     auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
-        LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
+        LiteralUtil::One(rhs->shape().element_type()).Clone()));
 
     // Explicitly broadcast scalar 1 to the output shape, to avoid implicit
     // broadcast in divide HLO as we are trying to eliminate implicit
@@ -2058,12 +2066,12 @@
       if (pad_literal == reduce_init_literal) {
         return true;
       }
-      auto converted_pad_literal = pad_literal.ConvertToShape(
-          reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+      auto converted_pad_literal =
+          pad_literal.ConvertToShape(reduce_init_value->shape());
       if (!converted_pad_literal.ok()) {
         return false;
       }
-      return *converted_pad_literal.ValueOrDie() == reduce_init_literal;
+      return converted_pad_literal.ValueOrDie() == reduce_init_literal;
     };
     // The pad value is usually a constant, so we handle that case and do not
     // try to get more fancy about proving equivalence in cases beyond that.
@@ -2213,170 +2221,155 @@
   return Status::OK();
 }
 
-Status AlgebraicSimplifierVisitor::HandleConvolution(
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
     HloInstruction* convolution) {
-  auto lhs = convolution->mutable_operand(0);
-  auto rhs = convolution->mutable_operand(1);
-  if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
-      ShapeUtil::IsZeroElementArray(rhs->shape())) {
-    return ReplaceWithNewInstruction(
-        convolution,
-        HloInstruction::CreateBroadcast(
-            convolution->shape(),
-            computation_->AddInstruction(HloInstruction::CreateConstant(
-                LiteralUtil::Zero(convolution->shape().element_type())
-                    .CloneToUnique())),
-            {}));
-  }
-
+  auto* lhs = convolution->mutable_operand(0);
+  auto* rhs = convolution->mutable_operand(1);
   const auto& window = convolution->window();
   const ConvolutionDimensionNumbers& dnums =
       convolution->convolution_dimension_numbers();
 
-  // Try to merge padding/dilation of the input with the convolution's window.
-  TF_ASSIGN_OR_RETURN(bool folded_input_pad, [&]() -> StatusOr<bool> {
-    if (lhs->opcode() != HloOpcode::kPad) {
-      return false;
-    }
-
-    // Convolution's padding is always zero, so bail if the kPad is adding
-    // something other than zero.
-    if (!IsAll(lhs->operand(1), 0)) {
-      return false;
-    }
-
-    const auto& padding = lhs->padding_config();
-
-    // Can't pad batch or feature dims.
-    for (int64 dim :
-         {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
-      const auto& p = padding.dimensions(dim);
-      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
-          p.interior_padding() != 0) {
-        return false;
-      }
-    }
-
-    // Compute the window which is the result of merging the kPad and the
-    // convolution's existing window.
-    Window new_window = window;
-    for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
-      auto& w = *new_window.mutable_dimensions(dim);
-      const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
-      // Edge padding composes with itself in the straightforward way, but
-      // composing interior padding is nontrivial, and we cowardly refuse to
-      // think about it. If we see interior padding in either the kPad or conv,
-      // bail if there's any sort of padding in the other.
-      if (p.interior_padding() != 0 &&
-          (w.padding_low() != 0 || w.padding_high() != 0 ||
-           w.base_dilation() != 1)) {
-        return false;
-      }
-      if (w.base_dilation() != 1 &&
-          (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
-           p.interior_padding() != 0)) {
-        return false;
-      }
-
-      w.set_padding_low(w.padding_low() + p.edge_padding_low());
-      w.set_padding_high(w.padding_high() + p.edge_padding_high());
-      if (p.interior_padding() != 0) {
-        CHECK_EQ(w.base_dilation(), 1);
-        w.set_base_dilation(1 + p.interior_padding());
-      }
-    }
-
-    auto new_conv = convolution->CloneWithNewOperands(
-        convolution->shape(), {lhs->mutable_operand(0), rhs});
-    new_conv->set_window(new_window);
-    TF_RETURN_IF_ERROR(
-        ReplaceWithNewInstruction(convolution, std::move(new_conv)));
-    return true;
-  }());
-
-  if (folded_input_pad) {
-    return Status::OK();
+  if (lhs->opcode() != HloOpcode::kPad) {
+    return false;
   }
 
-  // Try to merge dilation of the filter with the convolution's window.
-  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, [&]() -> StatusOr<bool> {
-    if (rhs->opcode() != HloOpcode::kPad) {
-      return false;
-    }
-
-    // Convolution's padding is always zero, so bail if the kPad is adding
-    // something other than zero.
-    if (!IsAll(rhs->operand(1), 0)) {
-      return false;
-    }
-
-    const auto& padding = rhs->padding_config();
-
-    // Can't pad or dilate feature dims.
-    for (int64 dim : {dnums.kernel_input_feature_dimension(),
-                      dnums.kernel_output_feature_dimension()}) {
-      const auto& p = padding.dimensions(dim);
-      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
-          p.interior_padding() != 0) {
-        return false;
-      }
-    }
-
-    // Compute the window which is the result of merging the kPad and the
-    // convolution's existing window.
-    Window new_window = convolution->window();
-    for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
-      auto& w = *new_window.mutable_dimensions(dim);
-      const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
-
-      // We can only do this transformation if p adds dilation to the filter --
-      // edge padding on the filter is not supported in conv.
-      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
-        return false;
-      }
-
-      // Nothing to do if the kPad for this dim is entirely a nop.
-      if (p.interior_padding() == 0) {
-        continue;
-      }
-
-      // We cowardly refuse to think about how dilation composes with itself;
-      // bail if both the kPad and conv have dilation on this dimension.
-      if (w.window_dilation() > 1) {
-        return false;
-      }
-      CHECK_EQ(w.window_dilation(), 1);
-      w.set_window_dilation(1 + p.interior_padding());
-      w.set_size(rhs->operand(0)->shape().dimensions(
-          dnums.kernel_spatial_dimensions(dim)));
-    }
-
-    auto new_conv = convolution->CloneWithNewOperands(
-        convolution->shape(), {lhs, rhs->mutable_operand(0)});
-    new_conv->set_window(new_window);
-    TF_RETURN_IF_ERROR(
-        ReplaceWithNewInstruction(convolution, std::move(new_conv)));
-    return true;
-  }());
-
-  if (folded_filter_pad) {
-    return Status::OK();
+  // Convolution's padding is always zero, so bail if the kPad is adding
+  // something other than zero.
+  if (!IsAll(lhs->operand(1), 0)) {
+    return false;
   }
 
+  const auto& padding = lhs->padding_config();
+
+  // Can't pad batch or feature dims.
+  for (int64 dim :
+       {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
+    const auto& p = padding.dimensions(dim);
+    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+        p.interior_padding() != 0) {
+      return false;
+    }
+  }
+
+  // Compute the window which is the result of merging the kPad and the
+  // convolution's existing window.
+  Window new_window = window;
+  for (int64 dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
+    auto& w = *new_window.mutable_dimensions(dim);
+    const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
+    // Edge padding composes with itself in the straightforward way, but
+    // composing interior padding is nontrivial, and we cowardly refuse to
+    // think about it. If we see interior padding in either the kPad or conv,
+    // bail if there's any sort of padding in the other.
+    if (p.interior_padding() != 0 &&
+        (w.padding_low() != 0 || w.padding_high() != 0 ||
+         w.base_dilation() != 1)) {
+      return false;
+    }
+    if (w.base_dilation() != 1 &&
+        (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+         p.interior_padding() != 0)) {
+      return false;
+    }
+
+    w.set_padding_low(w.padding_low() + p.edge_padding_low());
+    w.set_padding_high(w.padding_high() + p.edge_padding_high());
+    if (p.interior_padding() != 0) {
+      CHECK_EQ(w.base_dilation(), 1);
+      w.set_base_dilation(1 + p.interior_padding());
+    }
+  }
+
+  auto new_conv = convolution->CloneWithNewOperands(
+      convolution->shape(), {lhs->mutable_operand(0), rhs});
+  new_conv->set_window(new_window);
+  TF_RETURN_IF_ERROR(
+      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+  return true;
+}
+
+StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
+    HloInstruction* convolution) {
+  auto* lhs = convolution->mutable_operand(0);
+  auto* rhs = convolution->mutable_operand(1);
+  const ConvolutionDimensionNumbers& dnums =
+      convolution->convolution_dimension_numbers();
+
+  if (rhs->opcode() != HloOpcode::kPad) {
+    return false;
+  }
+
+  // Convolution's padding is always zero, so bail if the kPad is adding
+  // something other than zero.
+  if (!IsAll(rhs->operand(1), 0)) {
+    return false;
+  }
+
+  const auto& padding = rhs->padding_config();
+
+  // Can't pad or dilate feature dims.
+  for (int64 dim : {dnums.kernel_input_feature_dimension(),
+                    dnums.kernel_output_feature_dimension()}) {
+    const auto& p = padding.dimensions(dim);
+    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
+        p.interior_padding() != 0) {
+      return false;
+    }
+  }
+
+  // Compute the window which is the result of merging the kPad and the
+  // convolution's existing window.
+  Window new_window = convolution->window();
+  for (int64 dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
+    auto& w = *new_window.mutable_dimensions(dim);
+    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));
+
+    // We can only do this transformation if p adds dilation to the filter --
+    // edge padding on the filter is not supported in conv.
+    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
+      return false;
+    }
+
+    // Nothing to do if the kPad for this dim is entirely a nop.
+    if (p.interior_padding() == 0) {
+      continue;
+    }
+
+    // We cowardly refuse to think about how dilation composes with itself;
+    // bail if both the kPad and conv have dilation on this dimension.
+    if (w.window_dilation() > 1) {
+      return false;
+    }
+    CHECK_EQ(w.window_dilation(), 1);
+    w.set_window_dilation(1 + p.interior_padding());
+    w.set_size(rhs->operand(0)->shape().dimensions(
+        dnums.kernel_spatial_dimensions(dim)));
+  }
+
+  auto new_conv = convolution->CloneWithNewOperands(
+      convolution->shape(), {lhs, rhs->mutable_operand(0)});
+  new_conv->set_window(new_window);
+  TF_RETURN_IF_ERROR(
+      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
+  return true;
+}
+
+StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
+    HloInstruction* convolution) {
+  auto* lhs = convolution->mutable_operand(0);
+  auto* rhs = convolution->mutable_operand(1);
+  const auto& window = convolution->window();
+  const ConvolutionDimensionNumbers& dnums =
+      convolution->convolution_dimension_numbers();
+
   if (!enable_conv_simplification_) {
-    return Status::OK();
+    return false;
   }
-  // HandleConvolution tries to replace a convolution with a DOT instruction.
-  //
-  // Only add when bitcasts can be used:
-  // - if bitcasts are not supported, then reshapes could be used but will
-  //   end up with another copy.
-  // - if bitcasts are supported, the simplifier will be called again with
-  //   bitcasts_ == true.
 
-  // TODO(cwhipkey): b/31337498, make this layout insensitive.
+  // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
+  // layout-insensitive mode, for fear of adding nontrivial reshapes.
   if (!is_layout_sensitive_) {
-    return Status::OK();
+    return false;
   }
 
   const Shape& input_shape = lhs->shape();
@@ -2389,7 +2382,7 @@
   // Require the spatial dimensions in the kernel to have a bound of one.
   for (int64 i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
     if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
-      return Status::OK();
+      return false;
     }
   }
 
@@ -2400,7 +2393,7 @@
   // for a 1x1 window, so window dilation is no problem.
   if (window_util::HasStride(window) || window_util::HasPadding(window) ||
       window_util::HasBaseDilation(window)) {
-    return Status::OK();
+    return false;
   }
 
   // Also, the shapes must align for a rowmajor matmul:
@@ -2426,7 +2419,7 @@
                            dnums.kernel_input_feature_dimension()) <
        PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                            dnums.kernel_output_feature_dimension()))) {
-    return Status::OK();
+    return false;
   }
 
   auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
@@ -2468,7 +2461,7 @@
   if (!valid_bitcast_callback_(input_shape, new_input_shape) ||
       !valid_bitcast_callback_(filter_shape, new_filter_shape) ||
       !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
-    return Status::OK();
+    return false;
   }
 
   auto new_lhs = add_bitcast(new_input_shape, lhs);
@@ -2477,10 +2470,47 @@
   dot_dimension_numbers.add_lhs_contracting_dimensions(1);
   dot_dimension_numbers.add_rhs_contracting_dimensions(0);
   auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
-      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
-  dot->set_precision_config(convolution->precision_config());
+      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
+      convolution->precision_config()));
 
-  return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
+  TF_RETURN_IF_ERROR(
+      ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
+  return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+    HloInstruction* convolution) {
+  // Zero-sized input or filter.
+  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
+      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
+    return ReplaceWithNewInstruction(
+        convolution,
+        HloInstruction::CreateBroadcast(
+            convolution->shape(),
+            computation_->AddInstruction(HloInstruction::CreateConstant(
+                LiteralUtil::Zero(convolution->shape().element_type()))),
+            {}));
+  }
+
+  // Try to merge padding/dilation of the input with the convolution's window.
+  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
+  if (folded_input_pad) {
+    return Status::OK();
+  }
+
+  // Try to merge dilation of the filter with the convolution's window.
+  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
+  if (folded_filter_pad) {
+    return Status::OK();
+  }
+
+  // Try to replace the convolution with a kDot instruction.
+  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
+  if (replaced_with_dot) {
+    return Status::OK();
+  }
+
+  return Status::OK();
 }
 
 bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 43a891e..3fc1ba2 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1044,7 +1044,8 @@
   dim->set_window_reversal(false);
   // Create add computation.
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums));
+      ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(builder.Build());
   HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
                                              non_bitcasting_callback());
@@ -2260,9 +2261,11 @@
           .ValueOrDie();
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
-                                         window, dnums)
+                                         /*feature_group_count=*/1, window,
+                                         dnums)
           .ValueOrDie(),
-      lhs_pad, filter, window, dnums));
+      lhs_pad, filter, /*feature_group_count=*/1, window, dnums,
+      DefaultPrecisionConfig(2)));
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
@@ -2366,18 +2369,20 @@
                                               rhs_pad->shape().dimensions(3),
                                               testcase.orig_conv_window))
                       .ValueOrDie();
-  auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
-                                         window, dnums)
-          .ValueOrDie(),
-      input, rhs_pad, window, dnums));
 
   // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
   // after the transformation.
-  PrecisionConfigProto precision_config;
-  precision_config.add_operand_precision(PrecisionConfigProto::HIGH);
-  precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST);
-  orig_conv->set_precision_config(precision_config);
+  PrecisionConfig precision_config;
+  precision_config.add_operand_precision(PrecisionConfig::HIGH);
+  precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
+
+  builder.AddInstruction(HloInstruction::CreateConvolve(
+      ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+                                         /*feature_group_count=*/1, window,
+                                         dnums)
+          .ValueOrDie(),
+      input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+      precision_config));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
@@ -2396,9 +2401,10 @@
                               conv->operand(1)->shape().dimensions(2),
                               conv->operand(1)->shape().dimensions(3),
                               testcase.expected_conv_window));
-    EXPECT_THAT(
-        conv->precision_config().operand_precision(),
-        ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST));
+    EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
+                    ->precision_config()
+                    .operand_precision(),
+                ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
   }
 }
 
@@ -2522,8 +2528,9 @@
     HloInstruction* filter =
         b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
 
-    b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
-                                                    window, dnums));
+    b.AddInstruction(HloInstruction::CreateConvolve(
+        out_shape, input, filter,
+        /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
 
     // TODO(b/80488902): verify this module.
     auto module = HloTestBase::CreateNewModule();
@@ -2901,7 +2908,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums,
+                                                   DefaultPrecisionConfig(2)));
   std::unique_ptr<HloComputation> dot_computation(builder.Build());
 
   HloComputation::Builder call_builder(TestName() + ".Call");
@@ -2924,9 +2932,9 @@
   HloComputation::Builder builder(TestName());
   const float constant_scalar = 7.3f;
   std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
-  std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(constant_scalar).get(),
-       LiteralUtil::CreateR1<float>(constant_vector).get()});
+  Literal elements[] = {LiteralUtil::CreateR0<float>(constant_scalar),
+                        LiteralUtil::CreateR1<float>(constant_vector)};
+  Literal value = LiteralUtil::MakeTuple({&elements[0], &elements[1]});
   builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
 
   auto computation = module().AddEntryComputation(builder.Build());
@@ -3253,8 +3261,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  builder.AddInstruction(
-      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(
+      dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
   auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
@@ -3329,8 +3337,8 @@
   dot_dnums.add_rhs_contracting_dimensions(0);
 
   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
-  builder.AddInstruction(
-      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(
+      dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3393,8 +3401,8 @@
   dot_dnums.add_rhs_contracting_dimensions(0);
 
   Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n});
-  builder.AddInstruction(
-      HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(
+      dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3511,8 +3519,8 @@
   int64 dot_row_size = 1;
   int64 dot_col_size = spec.n;
   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
-  builder.AddInstruction(
-      HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(
+      dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -3581,8 +3589,8 @@
   int64 dot_row_size = spec.m;
   int64 dot_col_size = 1;
   Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size});
-  builder.AddInstruction(
-      HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums));
+  builder.AddInstruction(HloInstruction::CreateDot(
+      dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto computation = module().AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
index a16b85a..eda026a 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc
@@ -63,8 +63,8 @@
       new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
 
   TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
-                      MakeDotHlo(new_lhs, new_rhs, new_dim_numbers));
-  new_dot->set_precision_config(batch_dot->precision_config());
+                      MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
+                                 batch_dot->precision_config()));
 
   TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
                       MakeReshapeHlo(batch_dot->shape(), new_dot));
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index ec281ae..30d33e0 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -205,11 +205,11 @@
   const Shape feature_shape = scale->shape();
 
   auto zero_literal = LiteralUtil::CreateR0(0.0f);
-  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
   auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 
   auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
-  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
   auto epsilon = add(HloInstruction::CreateBroadcast(
       operand_shape,
       add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
@@ -331,7 +331,7 @@
   const Shape feature_shape = scale->shape();
 
   auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
-  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
   auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
       operand_shape,
       computation_->AddInstruction(
@@ -464,11 +464,11 @@
   const int64 elements_per_feature_int64 = size_in_elements / feature_count;
 
   auto zero_literal = LiteralUtil::CreateR0(0.0f);
-  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
+  TF_ASSIGN_OR_RETURN(zero_literal, zero_literal.Convert(ptype));
   auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
 
   auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
-  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
+  TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal.Convert(ptype));
   auto epsilon_scalar =
       add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
   auto epsilon_activation = add(
@@ -560,7 +560,7 @@
   auto elements_per_feature_literal =
       LiteralUtil::CreateR0<float>(elements_per_feature_int64);
   TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
-                      elements_per_feature_literal->Convert(ptype));
+                      elements_per_feature_literal.Convert(ptype));
   auto elements_per_feature = add(
       HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
   auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9b..f7ac8f5 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
 namespace {
 
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
 
 // Test that we expand BatchNormTraining.
 TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
   root = computation->root_instruction();
   // Make sure this operation is expanded.
   EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
   root = computation->root_instruction();
   // Make sure this operation is expanded.
   EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@
     epsilon=0.001, feature_index=1, sharding={maximal device=1}
 })";
 
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+  ParseAndVerifyModule(module_str);
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
 
-  for (auto* instruction : module->entry_computation()->instructions()) {
+  for (auto* instruction : module().entry_computation()->instructions()) {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 6363a21..5f93740 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -65,8 +65,12 @@
   }
 };
 
-class BFloat16ConversionFoldingTest : public HloTestBase {
+class BFloat16ConversionFoldingTest : public HloVerifiedTestBase {
  protected:
+  BFloat16ConversionFoldingTest()
+      : HloVerifiedTestBase(/*layout_sensitive=*/false,
+                            /*allow_mixed_precision=*/true) {}
+
   bool FoldConversions(HloModule* module) {
     TestBFloat16Support bfloat16_support_;
     BFloat16ConversionFolding fold(&bfloat16_support_);
@@ -102,7 +106,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(FoldConversions(module.get()));
+  EXPECT_TRUE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), add1);
   EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -137,7 +141,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert2);
   EXPECT_EQ(mul0->shape().element_type(), F32);
@@ -172,7 +176,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert2);
   EXPECT_EQ(sub0->shape().element_type(), F32);
@@ -202,7 +206,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(FoldConversions(module.get()));
+  EXPECT_FALSE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), convert1);
   EXPECT_EQ(gte->shape().element_type(), F32);
@@ -248,7 +252,7 @@
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(FoldConversions(module.get()));
+  EXPECT_TRUE(FoldConversions(module));
 
   EXPECT_EQ(computation->root_instruction(), tuple);
   EXPECT_EQ(tuple->operand(0), gte_a);
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index b08705d..cef0eba 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -23,7 +23,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
@@ -68,8 +68,12 @@
   }
 };
 
-class BFloat16NormalizationTest : public HloTestBase {
+class BFloat16NormalizationTest : public HloVerifiedTestBase {
  protected:
+  BFloat16NormalizationTest()
+      : HloVerifiedTestBase(/*layout_sensitive=*/false,
+                            /*allow_mixed_precision=*/true) {}
+
   bool Normalize(HloModule* module) {
     TestBFloat16Support bfloat16_support_;
     BFloat16Normalization normalization(&bfloat16_support_);
@@ -105,7 +109,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(Normalize(module.get()));
+  EXPECT_FALSE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), add1);
   EXPECT_EQ(add0->shape().element_type(), BF16);
@@ -133,7 +137,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
@@ -163,7 +167,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
@@ -201,7 +205,7 @@
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), reduce);
   EXPECT_EQ(reduce->called_computations().size(), 1);
@@ -259,7 +263,7 @@
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), gte);
   EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -286,7 +290,7 @@
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction(), gte);
   EXPECT_EQ(gte->shape().element_type(), BF16);
@@ -308,13 +312,16 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
   HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums));
+      HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(Normalize(module.get()));
+  EXPECT_TRUE(Normalize(module));
 
   EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
   EXPECT_EQ(dot->shape().element_type(), F32);
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 545a6ec..58f78f8 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -675,10 +675,8 @@
         continue;
       }
       if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
-        TF_ASSIGN_OR_RETURN(
-            auto converted_literal,
-            hlo->literal().ConvertToShape(hlo->shape(),
-                                          /*round_f32_to_bf16=*/true));
+        TF_ASSIGN_OR_RETURN(auto converted_literal,
+                            hlo->literal().ConvertToShape(hlo->shape()));
         auto new_constant = computation->AddInstruction(
             HloInstruction::CreateConstant(std::move(converted_literal)));
         TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 69b654d..e032b5c 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
@@ -55,8 +55,12 @@
   }
 };
 
-class BFloat16PropagationTest : public HloTestBase {
+class BFloat16PropagationTest : public HloVerifiedTestBase {
  protected:
+  BFloat16PropagationTest()
+      : HloVerifiedTestBase(/*layout_sensitive=*/false,
+                            /*allow_mixed_precision=*/true) {}
+
   // Runs the propagation pass on the given module, and returns whether the
   // module is changed after this pass.
   bool PropagatePrecision(HloModule* module) {
@@ -77,6 +81,16 @@
            inst->users()[0]->opcode() == HloOpcode::kConvert &&
            inst->users()[0]->shape().element_type() == BF16;
   }
+
+  std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
+                                            HloInstruction* lhs,
+                                            HloInstruction* rhs) {
+    DotDimensionNumbers dot_dnums;
+    dot_dnums.add_lhs_contracting_dimensions(1);
+    dot_dnums.add_rhs_contracting_dimensions(0);
+    return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+                                     DefaultPrecisionConfig(2));
+  }
 };
 
 // Tests that BF16 can propagate through select over non-tuple buffers, but not
@@ -95,22 +109,22 @@
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
   HloInstruction* add1 = builder.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
-  HloInstruction* pred = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b));
+  HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
   HloInstruction* sel = builder.AddInstruction(
       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
   HloInstruction* xpose =
       builder.AddInstruction(HloInstruction::CreateTranspose(
           ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
-  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a));
-  HloInstruction* root = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+  HloInstruction* dot = builder.AddInstruction(
+      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
+  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), root);
   EXPECT_TRUE(OutputsBF16(xpose));
@@ -136,13 +150,12 @@
       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
   HloInstruction* b = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+  HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_TRUE(OutputsBF16(dot->operand(0)));
@@ -150,10 +163,10 @@
   EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
   EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+      LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_a)),
       dot->operand(0)->literal()));
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+      LiteralUtil::ConvertF32ToBF16(LiteralUtil::CreateFromArray(array_b)),
       dot->operand(1)->literal()));
 }
 
@@ -189,8 +202,8 @@
           builder.AddInstruction(HloInstruction::CreateGetTupleElement(
               tuple0->shape(), tuple1, 0)),
           0));
-  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+  HloInstruction* dot = builder.AddInstruction(
+      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
 
   HloInstruction* output_tuple =
       builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
@@ -198,7 +211,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), output_tuple);
   EXPECT_TRUE(OutputsBF16(xpose));
@@ -231,13 +244,13 @@
       HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
 
   // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
-  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+  HloInstruction* dot = builder.AddInstruction(
+      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_TRUE(OutputsBF16(add1));
@@ -249,7 +262,7 @@
 // Tests that a non-fusion computation's root should not be changed.
 TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
   auto builder = HloComputation::Builder(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
 
   HloInstruction* a =
       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
@@ -258,8 +271,7 @@
   HloInstruction* add = builder.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
 
-  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add));
+  HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
 
   HloInstruction* tuple =
       builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
@@ -267,7 +279,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(PropagatePrecision(module.get()));
+  EXPECT_FALSE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), tuple);
   EXPECT_FALSE(OutputsBF16(add));
@@ -277,7 +289,7 @@
 TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
   auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
 
   HloInstruction* param = builder.AddInstruction(
       HloInstruction::CreateParameter(0, shape, "param"));
@@ -303,15 +315,14 @@
       HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
   HloInstruction* b_f1 = builder_f1.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
-  HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1));
+  HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
   auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
   auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
       dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), fusion1);
   EXPECT_TRUE(OutputsBF16(add));
@@ -326,7 +337,7 @@
 TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
   auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
-  Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+  Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
 
   HloInstruction* param = builder.AddInstruction(
       HloInstruction::CreateParameter(0, shape, "param"));
@@ -340,15 +351,15 @@
       builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
   HloInstruction* add_f = builder_f.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
-  HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
+  HloInstruction* dot_f =
+      builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
   auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
   auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
       dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(PropagatePrecision(module.get()));
+  EXPECT_FALSE(PropagatePrecision(module));
   EXPECT_EQ(computation->root_instruction(), fusion);
 }
 
@@ -390,12 +401,11 @@
       HloInstruction::CreateGetTupleElement(shape, fusion, 0));
   HloInstruction* gte1 = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, fusion, 1));
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1));
+  HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_TRUE(OutputsBF16(gte0));
@@ -440,12 +450,12 @@
   HloInstruction* xpose =
       builder.AddInstruction(HloInstruction::CreateTranspose(
           ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
-  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1));
+  HloInstruction* dot = builder.AddInstruction(
+      CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
 
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_FALSE(OutputsBF16(add0));
@@ -472,31 +482,36 @@
   auto builder_cond = HloComputation::Builder("cond");
   auto cond_param = builder_cond.AddInstruction(
       HloInstruction::CreateParameter(0, shape, "cond_param"));
-  auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, cond_param, cond_param));
+  auto cond_dot =
+      builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
   auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(
+              HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+                                          cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+              {1, 1}))))));
   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
 
   auto builder_body = HloComputation::Builder("body");
   auto body_param = builder_body.AddInstruction(
       HloInstruction::CreateParameter(0, shape, "body_param"));
-  auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, body_param, body_param));
+  auto body_dot =
+      builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
   auto body = module->AddEmbeddedComputation(builder_body.Build());
 
   auto while_hlo = builder.AddInstruction(
       HloInstruction::CreateWhile(shape, cond, body, add));
 
-  auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, while_hlo, while_hlo));
+  auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_TRUE(
@@ -528,10 +543,16 @@
       HloInstruction::CreateParameter(0, shape, "cond_param"));
   builder_cond.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})),
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1}))));
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
+              {1, 1})))),
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
+              {1, 1}))))));
   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
 
   auto builder_body = HloComputation::Builder("body");
@@ -552,11 +573,10 @@
   auto while_hlo = builder.AddInstruction(
       HloInstruction::CreateWhile(shape, cond, body, add));
 
-  auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, while_hlo, while_hlo));
+  auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_FALSE(PropagatePrecision(module.get()));
+  EXPECT_FALSE(PropagatePrecision(module));
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_FALSE(OutputsBF16(add));
   EXPECT_FALSE(OutputsBF16(body_fusion));
@@ -593,14 +613,20 @@
   // This add should prevent RHS from using BF16
   auto cond_add_rhs = builder_cond.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
-  auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+  auto cond_dot =
+      builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
   builder_cond.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
-      builder_cond.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(
+              HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+                                          cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+      builder_cond.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+              {1, 1}))))));
   auto cond = module->AddEmbeddedComputation(builder_cond.Build());
 
   auto builder_body = HloComputation::Builder("body");
@@ -610,10 +636,10 @@
       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
   auto body_rhs = builder_body.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
-  auto body_dot1 = builder_body.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
-  auto body_dot2 = builder_body.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs));
+  auto body_dot1 =
+      builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
+  auto body_dot2 =
+      builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
   auto body_transpose = builder_body.AddInstruction(
       HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
   builder_body.AddInstruction(
@@ -627,11 +653,10 @@
       HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
   auto rhs = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+  auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), dot);
   EXPECT_TRUE(OutputsBF16(lhs));
@@ -683,14 +708,20 @@
   auto cond0_add_rhs =
       builder_cond0.AddInstruction(HloInstruction::CreateBinary(
           shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
-  auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+  auto cond0_dot =
+      builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
   builder_cond0.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
-      builder_cond0.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
-      builder_cond0.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+      builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond0.AddInstruction(
+              HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+                                          cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
+      builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
+              {1, 1}))))));
   auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
 
   // Condition computation for the second while.
@@ -705,14 +736,20 @@
   auto cond1_add_lhs =
       builder_cond1.AddInstruction(HloInstruction::CreateBinary(
           shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
-  auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+  auto cond1_dot =
+      builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
   builder_cond1.AddInstruction(HloInstruction::CreateBinary(
       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
-      builder_cond1.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
-      builder_cond1.AddInstruction(HloInstruction::CreateSlice(
-          ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+      builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond1.AddInstruction(
+              HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+                                          cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
+      builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(F32, {}),
+          builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+              ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
+              {1, 1}))))));
   auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
 
   // Body computation shared by both whiles.
@@ -723,8 +760,8 @@
       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
   auto body_rhs = builder_body.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, body_param, 1));
-  auto body_dot = builder_body.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+  auto body_dot =
+      builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
   builder_body.AddInstruction(
       HloInstruction::CreateTuple({body_dot, body_rhs}));
   auto body = module->AddEmbeddedComputation(builder_body.Build());
@@ -734,23 +771,22 @@
   auto while1 = builder.AddInstruction(
       HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
 
-  auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot,
-      builder.AddInstruction(
-          HloInstruction::CreateGetTupleElement(shape, while0, 0)),
-      builder.AddInstruction(
-          HloInstruction::CreateGetTupleElement(shape, while0, 1))));
-  auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kDot,
-      builder.AddInstruction(
-          HloInstruction::CreateGetTupleElement(shape, while1, 0)),
-      builder.AddInstruction(
-          HloInstruction::CreateGetTupleElement(shape, while1, 1))));
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+  auto lhs = builder.AddInstruction(
+      CreateDot(shape,
+                builder.AddInstruction(
+                    HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+                builder.AddInstruction(
+                    HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+  auto rhs = builder.AddInstruction(
+      CreateDot(shape,
+                builder.AddInstruction(
+                    HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+                builder.AddInstruction(
+                    HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+  auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
   EXPECT_FALSE(OutputsBF16(body_dot));
   EXPECT_FALSE(OutputsBF16(body_rhs));
   EXPECT_FALSE(OutputsBF16(body_lhs));
@@ -792,7 +828,7 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), add2);
   EXPECT_EQ(add2->operand(0), add0);
@@ -821,15 +857,14 @@
       HloInstruction::CreateGetTupleElement(shape, domain, 0));
   HloInstruction* b_gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(shape, domain, 1));
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte));
+  HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte));
   HloInstruction* root = builder.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
   EXPECT_EQ(computation->root_instruction(), root);
 
   // test BF16 propagated through domain
@@ -867,15 +902,15 @@
       HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
   HloInstruction* b_trans = builder.AddInstruction(
       HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+  HloInstruction* dot =
+      builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
   HloInstruction* root = builder.AddInstruction(
       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
 
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  EXPECT_TRUE(PropagatePrecision(module.get()));
+  EXPECT_TRUE(PropagatePrecision(module));
 
   EXPECT_EQ(computation->root_instruction(), root);
   EXPECT_TRUE(OutputsBF16(a_trans));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 8b8c6bf..65fa951 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -30,7 +30,6 @@
 #include "tensorflow/compiler/xla/service/heap_simulator.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -617,18 +616,24 @@
   }
 
   // Only compute total fragmentation if all computations have schedules.
-  SequentialHloOrdering::HloModuleSequence module_sequence;
+  HloSchedule schedule(module_);
+  bool schedule_complete = true;
   for (const auto& computation : module_->computations()) {
-    const std::vector<const HloInstruction*>* sequence =
-        liveness_->hlo_ordering().SequentialOrder(*computation);
-    if (sequence != nullptr) {
-      module_sequence.emplace(computation, *sequence);
+    if (!computation->IsFusionComputation()) {
+      const std::vector<const HloInstruction*>* sequence =
+          liveness_->hlo_ordering().SequentialOrder(*computation);
+      if (sequence == nullptr) {
+        schedule_complete = false;
+      } else {
+        schedule.set_sequence(computation, *sequence);
+      }
     }
   }
-  if (module_sequence.size() == module_->computation_count()) {
+  if (schedule_complete) {
+    TF_RETURN_IF_ERROR(schedule.Verify());
     TF_ASSIGN_OR_RETURN(
         const int64 min_size,
-        HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
+        HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_));
     stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
   }
 
@@ -1064,7 +1069,7 @@
     // since buffers for kCall, kWhile, and kConditional sub-computations are
     // only live for the duration of their calling instructions.
     VLOG(1) << "Running whole-module heap simulation";
-    SequentialHloOrdering::HloModuleSequence module_sequence;
+    HloSchedule schedule(&assignment->module());
     FlatSet<const LogicalBuffer*> all_buffers_to_assign;
     for (const auto& pair : buffers_to_assign_sequentially) {
       const HloComputation* computation = pair.first;
@@ -1072,7 +1077,7 @@
       const std::vector<const HloInstruction*>* instruction_sequence =
           hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
-      module_sequence[computation] = *instruction_sequence;
+      schedule.set_sequence(computation, *instruction_sequence);
       all_buffers_to_assign.insert(buffers_to_assign.begin(),
                                    buffers_to_assign.end());
     }
@@ -1090,7 +1095,7 @@
           const HeapSimulator::Result result,
           HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
                                  absl::make_unique<LazyBestFitHeap>(alignment)),
-                             assignment->module(), module_sequence,
+                             assignment->module(), schedule,
                              assignment->points_to_analysis(),
                              assignment->buffer_size_, options));
       AssignBuffersFromHeapSimulator(result, assignment,
@@ -1121,7 +1126,7 @@
             HeapSimulator::Run(
                 absl::make_unique<DecreasingSizeRunsHeap>(
                     absl::make_unique<LazyBestFitHeap>(alignment)),
-                *computation, *instruction_sequence,
+                *computation, HloInstructionSequence(*instruction_sequence),
                 assignment->points_to_analysis(), assignment->buffer_size_,
                 options));
         AssignBuffersFromHeapSimulator(result, assignment,
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8bd1533..795beb9 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -30,16 +30,18 @@
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/macros.h"
 
 namespace xla {
@@ -120,14 +122,10 @@
       HloModule* module,
       absl::Span<const HloInstruction* const> instruction_sequence,
       int64 alignment = 1) {
-    SequentialHloOrdering::HloModuleSequence module_sequence;
-    module_sequence[module->entry_computation()] =
-        std::vector<const HloInstruction*>(instruction_sequence.begin(),
-                                           instruction_sequence.end());
+    HloSchedule schedule(module);
+    schedule.set_sequence(module->entry_computation(), instruction_sequence);
     return BufferAssigner::Run(
-               module,
-               absl::make_unique<SequentialHloOrdering>(module,
-                                                        module_sequence),
+               module, absl::make_unique<SequentialHloOrdering>(schedule),
                backend().compiler()->BufferSizeBytesFunction(),
                [alignment](LogicalBuffer::Color) { return alignment; },
                /*allow_input_output_aliasing=*/false,
@@ -1247,9 +1245,10 @@
   // Test that a tuple constant which is forwarded to the computation output
   // is properly handled.
   auto builder = HloComputation::Builder(TestName());
+  Literal elements[] = {LiteralUtil::CreateR0<int64>(0),
+                        LiteralUtil::CreateR0<int64>(1)};
   builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
-                              LiteralUtil::CreateR0<int64>(1).get()})));
+      LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
@@ -1490,10 +1489,13 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot_ab = builder.AddInstruction(
-      HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
-  auto dot_bc = builder.AddInstruction(
-      HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
+  auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot(
+      shape_2x4, param_a, param_b, dot_dnums, precision_config));
+  auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot(
+      shape_3x4, param_b, param_c, dot_dnums, precision_config));
   builder.AddInstruction(
       HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0));
 
@@ -1782,11 +1784,10 @@
 
   std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
                                                         int64 alignment = 1) {
-    auto sequence =
-        ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+    HloSchedule schedule =
+        ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
     return BufferAssigner::Run(
-               module,
-               absl::make_unique<SequentialHloOrdering>(module, sequence),
+               module, absl::make_unique<SequentialHloOrdering>(schedule),
                ByteSizeOf,
                [alignment](LogicalBuffer::Color) { return alignment; },
                /*allow_input_output_aliasing=*/false,
@@ -2093,17 +2094,25 @@
   // Create a sequential order among all the instructions in the entry
   // computation, since the issue this test stresses depends on the order the
   // nodes are traversed during BufferAssignment.
-  SequentialHloOrdering::HloModuleSequence sequence;
-  sequence[module->entry_computation()] = {
-      token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape(),
+                                     /*pointer_size=*/sizeof(void*));
+      }));
+  schedule.set_sequence(
+      module->entry_computation(),
+      {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple});
+  TF_ASSERT_OK(schedule.Verify());
+
   TF_ASSERT_OK_AND_ASSIGN(
       auto assignment,
-      BufferAssigner::Run(
-          module, absl::make_unique<SequentialHloOrdering>(module, sequence),
-          backend().compiler()->BufferSizeBytesFunction(),
-          [](LogicalBuffer::Color) { return 1; },
-          /*allow_input_output_aliasing=*/false,
-          /*allocate_buffers_for_constants=*/true));
+      BufferAssigner::Run(module,
+                          absl::make_unique<SequentialHloOrdering>(schedule),
+                          backend().compiler()->BufferSizeBytesFunction(),
+                          [](LogicalBuffer::Color) { return 1; },
+                          /*allow_input_output_aliasing=*/false,
+                          /*allocate_buffers_for_constants=*/true));
 
   // The result tuple elements must be assigned with different buffers.
   TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0}));
@@ -2260,29 +2269,6 @@
             GetAllocation(*buffers, param0, {1, 1}));
 }
 
-static bool IsPostOrderTraversal(
-    const std::vector<const HloInstruction*>& sequence) {
-  tensorflow::gtl::FlatSet<const HloInstruction*> seen_so_far;
-  auto has_not_been_seen_yet = [&](const HloInstruction* instruction) {
-    return seen_so_far.count(instruction) == 0;
-  };
-
-  for (auto instruction : sequence) {
-    if (std::any_of(instruction->operands().begin(),
-                    instruction->operands().end(), has_not_been_seen_yet) ||
-        std::any_of(instruction->control_predecessors().begin(),
-                    instruction->control_predecessors().end(),
-                    has_not_been_seen_yet)) {
-      return false;  // Not a post order.
-    }
-    if (!seen_so_far.insert(instruction).second) {
-      return false;  // Not a "traversal".
-    }
-  }
-
-  return true;
-}
-
 TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
   auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
@@ -2337,27 +2323,27 @@
 
   RunCopyInsertion(module);
 
-  auto sequence =
-      ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
+  HloSchedule schedule =
+      ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie();
 
-  // To trigger b/38494731, we want a specific Hlo sequence for the
+  // To trigger b/38494731, we want a specific Hlo schedule for the
   // root computation, so we overwrite that entry with a manually
   // crafted sequence.
-  sequence[module->entry_computation()] = {
-      input1, weights1, one,     output1, while1->operand(0), while1,
-      input0, weights0, zero,    output0, while0->operand(0), while0,
-      gte0,   gte1,     root_add};
+  schedule.set_sequence(module->entry_computation(),
+                        {input1, weights1, one, output1, while1->operand(0),
+                         while1, input0, weights0, zero, output0,
+                         while0->operand(0), while0, gte0, gte1, root_add});
 
-  // If this ASSERT_TRUE fails, we constructed a bogus sequence above
-  // and this test itself is buggy.
-  ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+  // If this ASSERT fails, we constructed a bogus sequence above and this test
+  // itself is buggy.
+  TF_ASSERT_OK(schedule.Verify());
 
   auto assignment =
-      BufferAssigner::Run(
-          module, absl::make_unique<SequentialHloOrdering>(module, sequence),
-          ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
-          /*allow_input_output_aliasing=*/false,
-          /*allocate_buffers_for_constants=*/true)
+      BufferAssigner::Run(module,
+                          absl::make_unique<SequentialHloOrdering>(schedule),
+                          ByteSizeOf, [](LogicalBuffer::Color) { return 1; },
+                          /*allow_input_output_aliasing=*/false,
+                          /*allocate_buffers_for_constants=*/true)
           .ConsumeValueOrDie();
 
   EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index 26e26e3..17e5090 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -27,6 +27,7 @@
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace xla {
 namespace {
@@ -166,12 +167,12 @@
   auto module = CreateNewModule();
   HloComputation* entry = module->AddEntryComputation(builder.Build());
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  sequence.insert({entry, {param0, negate, param1, exp, add}});
-  auto liveness = BufferLiveness::Run(module.get(),
-                                      absl::make_unique<SequentialHloOrdering>(
-                                          module.get(), sequence))
-                      .ConsumeValueOrDie();
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+  auto liveness =
+      BufferLiveness::Run(module.get(),
+                          absl::make_unique<SequentialHloOrdering>(schedule))
+          .ConsumeValueOrDie();
 
   // Entry parameters interfere as if they are defined simultaneously at
   // the very beginning.
@@ -291,13 +292,12 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
-  SequentialHloOrdering::HloModuleSequence module_sequence;
-  std::vector<const HloInstruction*> order = {param, negate, exp, add};
-  module_sequence.emplace(computation, order);
-  auto liveness = BufferLiveness::Run(module.get(),
-                                      absl::make_unique<SequentialHloOrdering>(
-                                          module.get(), module_sequence))
-                      .ConsumeValueOrDie();
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation, {param, negate, exp, add});
+  auto liveness =
+      BufferLiveness::Run(module.get(),
+                          absl::make_unique<SequentialHloOrdering>(schedule))
+          .ConsumeValueOrDie();
 
   EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@@ -339,14 +339,14 @@
   auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build(add));
 
-  SequentialHloOrdering::HloModuleSequence module_sequence;
-  std::vector<const HloInstruction*> order = {param,     add,  recv,
-                                              recv_done, send, send_done};
-  module_sequence.emplace(computation, order);
-  auto liveness = BufferLiveness::Run(module.get(),
-                                      absl::make_unique<SequentialHloOrdering>(
-                                          module.get(), module_sequence))
-                      .ConsumeValueOrDie();
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(computation,
+                        {param, add, token, recv, recv_done, send, send_done});
+  TF_ASSERT_OK(schedule.Verify());
+  auto liveness =
+      BufferLiveness::Run(module.get(),
+                          absl::make_unique<SequentialHloOrdering>(schedule))
+          .ConsumeValueOrDie();
 
   EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
   // Check the root instruction (add) buffer interferes with the recv buffer.
@@ -440,15 +440,15 @@
   // computation. The buffer containing {0, 1} is copied by GetTupleElement, and
   // the buffers containing {3} and 3 are dead.
   auto builder = HloComputation::Builder(TestName());
-  auto inner_tuple0 =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
-                              LiteralUtil::CreateR0<int64>(1).get()});
-  auto inner_tuple1 =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+  Literal elements0[] = {LiteralUtil::CreateR0<int64>(0),
+                         LiteralUtil::CreateR0<int64>(1)};
+  auto inner_tuple0 = LiteralUtil::MakeTuple({&elements0[0], &elements0[1]});
+  Literal element1 = LiteralUtil::CreateR0<int64>(3);
+  auto inner_tuple1 = LiteralUtil::MakeTuple({&element1});
   auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+      LiteralUtil::MakeTuple({&inner_tuple0, &inner_tuple1})));
   builder.AddInstruction(HloInstruction::CreateGetTupleElement(
-      inner_tuple0->shape(), tuple_constant, 0));
+      inner_tuple0.shape(), tuple_constant, 0));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index cc80b74..34f3f91 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +31,7 @@
 
 using ::testing::UnorderedElementsAre;
 
-class CallGraphTest : public HloTestBase {
+class CallGraphTest : public HloVerifiedTestBase {
  protected:
   // Build and return a trivial computation taking and returning a scalar.
   std::unique_ptr<HloComputation> MakeScalarComputation(
@@ -96,7 +96,7 @@
   auto module = CreateNewModule();
   HloComputation* computation =
       module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(1, call_graph->nodes().size());
   EXPECT_TRUE(call_graph->IsFlattened());
 
@@ -118,7 +118,7 @@
   HloComputation* unreachable_computation =
       module->AddEmbeddedComputation(MakeScalarComputation());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -140,7 +140,7 @@
   HloComputation* entry_computation = module->AddEntryComputation(
       MakeMappingComputation(map_computation, /*callsites=*/5));
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   const CallGraphNode& entry_node = call_graph->GetNode(entry_computation);
@@ -169,7 +169,7 @@
   HloComputation* entry_computation = module->AddEntryComputation(
       MakeCallingComputation(called_computation, /*callsites=*/3));
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   // The called computation is only called from one other computation, but there
@@ -210,7 +210,7 @@
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(2, call_graph->nodes().size());
 
   EXPECT_FALSE(call_graph->IsFlattened());
@@ -259,7 +259,7 @@
   HloComputation* entry_computation =
       module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   EXPECT_EQ(3, call_graph->nodes().size());
 
@@ -328,7 +328,7 @@
     entry_computation = module->AddEntryComputation(builder.Build());
   }
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(5, call_graph->nodes().size());
   EXPECT_FALSE(call_graph->IsFlattened());
 
@@ -452,7 +452,7 @@
     entry_computation = module->AddEntryComputation(builder.Build());
   }
 
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(5, call_graph->nodes().size());
 
   // Verify NearestAncestorsInSameComputation for various instructions in the
@@ -482,7 +482,7 @@
   auto module = CreateNewModule();
   HloComputation* computation =
       module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   std::vector<HloComputation*> visited;
   TF_ASSERT_OK(call_graph->VisitNodes([&visited](const CallGraphNode& node) {
@@ -499,7 +499,7 @@
       module->AddEntryComputation(MakeScalarComputation());
   HloComputation* unreachable_computation =
       module->AddEmbeddedComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   // Test visitation of only reachable nodes.
   {
@@ -533,7 +533,7 @@
   // Test that the call graph visitor properly propagates errors.
   auto module = CreateNewModule();
   module->AddEntryComputation(MakeScalarComputation());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 
   Status status = call_graph->VisitNodes(
       [](const CallGraphNode&) { return InternalError("Visitation failed"); });
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f..e6b5665 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@
 
 // Tests for call inlining that are most tractable at the HLO level (vs
 // ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
 
 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
   // "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@
   auto computation = module->AddEntryComputation(outer.Build());
 
   CallInliner call_inliner;
-  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
   ASSERT_TRUE(mutated);
   EXPECT_THAT(computation->root_instruction(), op::Constant());
   EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@
 
   HloComputation::Builder call_false_builder(TestName() + ".call_false");
   call_false_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, pred, "param"));
+  call_false_builder.AddInstruction(
       HloInstruction::CreateCall(pred, {}, false_computation));
   HloComputation* call_false =
       module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@
   auto computation = module->AddEntryComputation(outer.Build());
 
   CallInliner call_inliner;
-  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
   ASSERT_TRUE(mutated);
   EXPECT_THAT(
       computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@
   module->AddEntryComputation(outer.Build());
 
   CallInliner call_inliner;
-  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
   ASSERT_TRUE(mutated);
 }
 
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 9c81a86..0ac4a65 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -214,8 +214,8 @@
     expanded_filter = add(HloInstruction::CreateConcatenate(
         expanded_filter_shape, concat_operands, input_feature_dim));
   }
-  auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>(
-      LiteralUtil::Zero(expanded_filter_shape.element_type()))));
+  auto zero = add(HloInstruction::CreateConstant(
+      LiteralUtil::Zero(expanded_filter_shape.element_type())));
   auto zero_filter =
       add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
   auto new_filter = add(
@@ -223,8 +223,8 @@
                                     filter_mask, expanded_filter, zero_filter));
   auto new_convolution = HloInstruction::CreateConvolve(
       convolution->shape(), convolution->mutable_operand(0), new_filter,
-      convolution->window(), dim_numbers, /*feature_group_count=*/1);
-  new_convolution->set_precision_config(convolution->precision_config());
+      /*feature_group_count=*/1, convolution->window(), dim_numbers,
+      convolution->precision_config());
   TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
       convolution, std::move(new_convolution)));
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d412578..8cc522a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -122,7 +122,7 @@
         "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
         "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xla/service:hlo_proto_util",
-        "//tensorflow/compiler/xla/service:hlo_scheduling",
+        "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
         "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:indexed_array_analysis",
@@ -670,6 +670,7 @@
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/service:transpose_folding",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
@@ -800,6 +801,7 @@
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -821,6 +823,7 @@
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -945,6 +948,7 @@
         "//tensorflow/compiler/xla/service:hlo_graph_dumper",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",
     ],
@@ -970,6 +974,7 @@
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
index 098ce17..2d99784 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -130,9 +130,9 @@
       // change the dimension mapping but not the dimension sizes. For
       // example, input height and width are the same as before the reshapes.
       HloInstruction* new_conv = module->entry_computation()->AddInstruction(
-          HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
-                                         hlo->window(), new_dnums));
-      new_conv->set_precision_config(hlo->precision_config());
+          HloInstruction::CreateConvolve(
+              new_conv_shape, new_input, new_kernel, hlo->feature_group_count(),
+              hlo->window(), new_dnums, hlo->precision_config()));
 
       // Reshape the output back to the shape of the original convolution.
       TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 547d4c6..2083f44 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 
 #include "tensorflow/compiler/xla/test_helpers.h"
@@ -32,7 +32,7 @@
 
 using ::testing::ElementsAre;
 
-class ConvCanonicalizationTest : public HloTestBase {
+class ConvCanonicalizationTest : public HloVerifiedTestBase {
  public:
   ConvCanonicalizationTest() {
     for (int i = 0; i < 2; ++i) {
@@ -84,7 +84,8 @@
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(
           F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
-      input, kernel, conv_window_, dnums));
+      input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+      DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -95,7 +96,7 @@
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
   ConvCanonicalization conv_canonicalization(&target_machine_features);
-  EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(conv_canonicalization.Run(module).ValueOrDie());
 
   const HloInstruction* output_reshape = entry_computation->root_instruction();
   EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
@@ -146,7 +147,8 @@
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(
           F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
-      input, kernel, conv_window_, dnums));
+      input, kernel, /*feature_group_count=*/1, conv_window_, dnums,
+      DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
@@ -156,7 +158,7 @@
         return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
       });
   ConvCanonicalization conv_canonicalization(&target_machine_features);
-  EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(conv_canonicalization.Run(module).ValueOrDie());
 }
 
 }  // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 796f365..18fc144 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -77,12 +77,12 @@
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
 #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
@@ -584,16 +584,14 @@
   // computation. Using this sequence enables tighter buffer liveness analysis
   // and reduced memory usage (as compared to using DependencyHloOrdering).
   TF_ASSIGN_OR_RETURN(
-      SequentialHloOrdering::HloModuleSequence module_sequence,
-      ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
-                                   DFSMemoryScheduler));
+      HloSchedule schedule,
+      ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler));
 
   // Run buffer allocation on the HLO graph.
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<BufferAssignment> assignment,
       BufferAssigner::Run(module.get(),
-                          absl::make_unique<SequentialHloOrdering>(
-                              module.get(), module_sequence),
+                          absl::make_unique<SequentialHloOrdering>(schedule),
                           BufferSizeBytesFunction(), memory_alignment,
                           /*allow_input_output_aliasing=*/false,
                           /*allocate_buffers_for_constants=*/true));
@@ -627,9 +625,10 @@
     }
     TF_RETURN_IF_ERROR(
         ir_emitter
-            .EmitComputation(embedded_computation, embedded_computation->name(),
-                             /*is_top_level_computation=*/false,
-                             &module_sequence.at(embedded_computation))
+            .EmitComputation(
+                embedded_computation, embedded_computation->name(),
+                /*is_top_level_computation=*/false,
+                &schedule.sequence(embedded_computation).instructions())
             .status());
   }
   string function_name_prefix = entry_computation->name().empty()
@@ -637,9 +636,10 @@
                                     : entry_computation->name();
   TF_ASSIGN_OR_RETURN(
       llvm::Function * entry_function,
-      ir_emitter.EmitComputation(entry_computation, function_name_prefix,
-                                 /*is_top_level_computation=*/true,
-                                 &module_sequence.at(entry_computation)));
+      ir_emitter.EmitComputation(
+          entry_computation, function_name_prefix,
+          /*is_top_level_computation=*/true,
+          &schedule.sequence(entry_computation).instructions()));
 
   string function_name = [&]() {
     llvm::SmallVector<char, 40> function_name_vector;
@@ -771,20 +771,18 @@
     VLOG(2) << "After optimization:";
     XLA_VLOG_LINES(2, module->ToString());
 
-    TF_ASSIGN_OR_RETURN(
-        SequentialHloOrdering::HloModuleSequence module_sequence,
-        ScheduleComputationsInModule(*module, BufferSizeBytesFunction()));
+    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+                        ScheduleModule(*module, BufferSizeBytesFunction()));
 
     // Run buffer analysis on the HLO graph. This analysis figures out which
     // temporary buffers are required to run the computation.
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<BufferAssignment> assignment,
-        BufferAssigner::Run(
-            module,
-            absl::make_unique<SequentialHloOrdering>(module, module_sequence),
-            BufferSizeBytesFunction(), memory_alignment,
-            /*allow_input_output_aliasing=*/false,
-            /*allocate_buffers_for_constants=*/true));
+        BufferAssigner::Run(module,
+                            absl::make_unique<SequentialHloOrdering>(schedule),
+                            BufferSizeBytesFunction(), memory_alignment,
+                            /*allow_input_output_aliasing=*/false,
+                            /*allocate_buffers_for_constants=*/true));
     // BufferAssignment::ToString() includes a header, so no need for us to
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
@@ -824,18 +822,18 @@
       }
       TF_RETURN_IF_ERROR(
           ir_emitter
-              .EmitComputation(embedded_computation,
-                               embedded_computation->name(),
-                               /*is_top_level_computation=*/false,
-                               &module_sequence.at(embedded_computation))
+              .EmitComputation(
+                  embedded_computation, embedded_computation->name(),
+                  /*is_top_level_computation=*/false,
+                  &schedule.sequence(embedded_computation).instructions())
               .status());
     }
     const string& entry_point_name = options.entry_point_name();
-    TF_ASSIGN_OR_RETURN(
-        llvm::Function * entry_function,
-        ir_emitter.EmitComputation(computation, entry_point_name,
-                                   /*is_top_level_computation=*/true,
-                                   &module_sequence.at(computation)));
+    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
+                        ir_emitter.EmitComputation(
+                            computation, entry_point_name,
+                            /*is_top_level_computation=*/true,
+                            &schedule.sequence(computation).instructions()));
 
     CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name));
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index 4db7fa4..c9fb34b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -25,7 +25,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -52,7 +52,7 @@
   return count;
 }
 
-class CpuCopyInsertionTest : public HloTestBase {
+class CpuCopyInsertionTest : public HloVerifiedTestBase {
  protected:
   void InsertCopies(HloModule* module) {
     CpuCopyInsertion copy_insertion;
@@ -90,7 +90,7 @@
 
   module->AddEntryComputation(builder.Build());
 
-  InsertCopies(module.get());
+  InsertCopies(module);
 
   EXPECT_EQ(CountCopies(*module), 3);
 
@@ -127,7 +127,7 @@
 
   module->AddEntryComputation(builder.Build());
 
-  InsertCopies(module.get());
+  InsertCopies(module);
 
   EXPECT_EQ(CountCopies(*subcomputation), 2);
   EXPECT_THAT(subcomputation->root_instruction(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
index 0f463e6..be1208f 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@
 #include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -25,7 +25,7 @@
 
 using ::testing::HasSubstr;
 
-class CpuHloSupportCheckerTest : public HloTestBase {
+class CpuHloSupportCheckerTest : public HloVerifiedTestBase {
  protected:
   CpuHloSupportChecker& checker() { return checker_; }
 
@@ -45,7 +45,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK(checker().Run(module.get()).status());
+  TF_ASSERT_OK(checker().Run(module).status());
 }
 
 TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Status status = checker().Run(module.get()).status();
+  Status status = checker().Run(module).status();
   ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
   EXPECT_THAT(status.error_message(),
               HasSubstr("CPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 284929c..7d99b91 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 
 namespace op = xla::testing::opcode_matchers;
 
@@ -38,7 +39,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
+  return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+                                   precision_config);
 }
 
 TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
@@ -692,8 +697,8 @@
   auto* addend = builder.AddInstruction(
       HloInstruction::CreateParameter(2, dot_shape, "param2"));
 
-  auto* dot = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+  auto* dot =
+      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
   builder.AddInstruction(
       HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 9363af3..4668f38 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -70,7 +70,7 @@
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto dot_a_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
   auto dot_b_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
   builder.AddInstruction(HloInstruction::CreateBinary(
       result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
 
@@ -151,9 +151,9 @@
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
   auto dot_a_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+      CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
   auto dot_b_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+      CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
   auto tuple_result = builder.AddInstruction(
       HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
 
@@ -189,7 +189,7 @@
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateParameter(0, rhs_shape, "param0"));
   auto dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@
   auto dot_rhs = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
   auto dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+      CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
 
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@
       HloInstruction::CreateParameter(1, dot_shape, "param1"));
   HloInstruction* dot_rhs = builder.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
-  HloInstruction* dot_result = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+  HloInstruction* dot_result =
+      builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
   HloInstruction* add_result;
   if (dot_operand_idx_in_add == 0) {
     add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index e5cf15c..df8c2a6 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -110,7 +110,7 @@
 StatusOr<llvm::Function*> IrEmitter::EmitComputation(
     HloComputation* computation, const string& function_name_prefix,
     bool is_top_level_computation,
-    std::vector<const HloInstruction*>* instruction_order) {
+    const std::vector<const HloInstruction*>* instruction_order) {
   string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
   VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
           << "]; ordered? " << (instruction_order != nullptr);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 58a333b..3df9946 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -98,7 +98,7 @@
   StatusOr<llvm::Function*> EmitComputation(
       HloComputation* computation, const string& function_name_prefix,
       bool is_top_level_computation,
-      std::vector<const HloInstruction*>* instruction_order);
+      const std::vector<const HloInstruction*>* instruction_order);
 
   llvm::IRBuilder<>* b() { return &b_; }
 
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index a84ee78..fad7633 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -35,9 +35,7 @@
   cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
 
   ParallelTaskAssignmentTest()
-      : HloVerifiedTestBase(/*layout_sensitive=*/false,
-                            /*allow_mixed_precision=*/false),
-        target_machine_features_([](int64 shape_size) {
+      : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
           return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
         }) {}
 
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 942e2dd..55d5925 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -37,21 +37,20 @@
   xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
 
   // Transfer parameters.
-  std::unique_ptr<xla::Literal> param0_literal =
+  xla::Literal param0_literal =
       xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
   std::unique_ptr<xla::GlobalData> param0_data =
-      client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::LiteralUtil::CreateR2<float>(
-          {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+  xla::Literal param1_literal = xla::LiteralUtil::CreateR2<float>(
+      {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
   std::unique_ptr<xla::GlobalData> param1_data =
-      client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client->TransferToServer(param1_literal).ConsumeValueOrDie();
 
   // Build computation.
   xla::XlaBuilder builder("");
-  auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Add(p1, p0, {0});
 
   xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
@@ -59,17 +58,16 @@
 
   // Execute and transfer result of computation.
   xla::ExecutionProfile profile;
-  xla::StatusOr<std::unique_ptr<xla::Literal>> result =
-      client->ExecuteAndTransfer(
-          computation,
-          /*arguments=*/{param0_data.get(), param1_data.get()},
-          /*execution_options=*/nullptr,
-          /*execution_profile=*/&profile);
-  std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+  xla::StatusOr<xla::Literal> result = client->ExecuteAndTransfer(
+      computation,
+      /*arguments=*/{param0_data.get(), param1_data.get()},
+      /*execution_options=*/nullptr,
+      /*execution_profile=*/&profile);
+  xla::Literal actual = result.ConsumeValueOrDie();
 
   LOG(INFO) << absl::StrFormat("computation took %dns",
                                profile.compute_time_ns());
-  LOG(INFO) << actual->ToString();
+  LOG(INFO) << actual.ToString();
 
   return 0;
 }
diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
index 7d8e51f..1a3d82d 100644
--- a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc
@@ -19,14 +19,14 @@
 #include <random>
 
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace xla {
 namespace cpu {
 namespace {
 
-class ShapePartitionAssignerTest : public HloTestBase {
+class ShapePartitionAssignerTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<int64> Vec;
 
@@ -91,7 +91,7 @@
             expected_partitions);
 }
 
-class ShapePartitionIteratorTest : public HloTestBase {
+class ShapePartitionIteratorTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<std::pair<int64, int64>> Partition;
 };
@@ -145,7 +145,7 @@
   }
 }
 
-class RandomShapePartitionIteratorTest : public HloTestBase {
+class RandomShapePartitionIteratorTest : public HloVerifiedTestBase {
  protected:
   typedef std::vector<std::pair<int64, int64>> Partition;
   RandomShapePartitionIteratorTest()
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 2384166..c55206e 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -48,6 +48,7 @@
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
@@ -121,6 +122,7 @@
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
         "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index fcd87b3..18ee25b 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
 #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace xla {
@@ -69,8 +70,7 @@
   HloInstruction* rhs = builder.AddInstruction(
       HloInstruction::CreateParameter(1, param_shape, "input"));
 
-  builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+  builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs));
   CompileAndCheck(builder.Build(), spec.filecheck_lines);
 }
 
@@ -87,8 +87,7 @@
   HloInstruction* lhs_transposed = builder.AddInstruction(
       HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
 
-  builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+  builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs));
   CompileAndCheck(builder.Build(), spec.filecheck_lines);
 }
 
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 2272105..1deb412 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -25,7 +25,7 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/test.h"
@@ -34,7 +34,7 @@
 namespace cpu {
 namespace {
 
-class CpuFusionTest : public HloTestBase {
+class CpuFusionTest : public HloVerifiedTestBase {
  protected:
   CpuFusionTest() {}
 
@@ -45,7 +45,7 @@
   auto builder = HloComputation::Builder(TestName());
   auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
   auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
-  Shape vshape = input_literal1->shape();
+  Shape vshape = input_literal1.shape();
 
   auto input1 = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal1)));
@@ -61,7 +61,7 @@
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -75,16 +75,16 @@
   EXPECT_EQ(4, fusion_instruction->fused_instruction_count());
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
-  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, *result, error_spec_);
+  LiteralTestUtil::ExpectR1Near<float>({1.0, 40.0, -5.0}, result, error_spec_);
 }
 
 TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
   auto builder = HloComputation::Builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
-  Shape vshape = input_literal->shape();
+  Shape vshape = input_literal.shape();
 
   auto input = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
@@ -108,7 +108,7 @@
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -122,11 +122,10 @@
   EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
-  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, *result,
-                                       error_spec_);
+  LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0}, result, error_spec_);
 }
 
 TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) {
@@ -135,7 +134,7 @@
   auto module = CreateNewModule();
   auto builder = HloComputation::Builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
-  Shape vshape = input_literal->shape();
+  Shape vshape = input_literal.shape();
 
   auto input = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
@@ -184,7 +183,7 @@
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The computation root instruction was fused. Verify the fusion instruction
   // is now the root.
@@ -209,11 +208,11 @@
       << fusion_instruction2->fused_instructions_computation()->ToString();
 
   // Compile and execute the computation.
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
 
   // Check the output correctness.
   LiteralTestUtil::ExpectR1Near<float>({14.0, 40.0, 40.0, 14.0, 40.0, 40.0},
-                                       *result, error_spec_);
+                                       result, error_spec_);
 }
 
 TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
@@ -232,7 +231,7 @@
   // each fusion instruction to ensure that negate is not duplicated.
   auto builder = HloComputation::Builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
-  Shape vshape = input_literal->shape();
+  Shape vshape = input_literal.shape();
 
   auto constant = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
@@ -256,7 +255,7 @@
 
   // Run fusion.
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   auto fusion1 = result->operand(0);
   auto fusion2 = result->operand(1);
@@ -315,7 +314,7 @@
   module->AddEntryComputation(builder.Build());
 
   CpuInstructionFusion fusion;
-  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(fusion.Run(module).ValueOrDie());
 
   // The only fusion instruction should be operand 0 of the tuple (formerly
   // negate1).
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c35569c..5cc6d01 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -58,52 +58,52 @@
 };
 
 TEST_F(InfeedTest, SingleInfeedR0Bool) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+  TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
 }
 
 TEST_F(InfeedTest, SingleInfeedR1U32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+  TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 }
 
 TEST_F(InfeedTest, SingleInfeedR2F32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+  TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 }
 
 TEST_F(InfeedTest, SingleInfeedR3F32) {
   TestInfeedRoundTrip(
-      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
-                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+      LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+                             {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 }
 
 TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
   const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
   const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
 
-  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
       {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
        {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
       r3_dim0minor));
 
-  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
       {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
        {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
       r3_dim0major));
 }
 
 TEST_F(InfeedTest, SingleInfeedR4S32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+  TestInfeedRoundTrip(LiteralUtil::CreateR4(
       {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
        {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 }
 
 TEST_F(InfeedTest, SingleInfeedTuple) {
-  TestInfeedRoundTrip(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
-                               LiteralUtil::CreateR0<bool>(false).get()}));
+  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+       LiteralUtil::CreateR0<bool>(false)}));
 }
 
 TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
-  TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+  TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
 }
 
 // Tests Infeed operation used in a while loop, as in the code below. The
@@ -157,21 +157,21 @@
 
   // Send 5 Infeed data of shape F32[3].
   ASSERT_IS_OK(
-      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({1, 2, 3})));
   ASSERT_IS_OK(
-      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
+      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({4, 5, 6})));
   ASSERT_IS_OK(
-      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({7, 8, 9})));
   ASSERT_IS_OK(
-      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({10, 11, 12})));
   ASSERT_IS_OK(
-      client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
+      client_->TransferToInfeed(LiteralUtil::CreateR1<float>({13, 14, 15})));
 
   delete computation_thread;  // Joins the thread.
   auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
 
   // Only the first 3 infeed data should be added.
-  LiteralTestUtil::ExpectR0Near<float>(45.0f, *result_literal, ErrorSpec{1e-7});
+  LiteralTestUtil::ExpectR0Near<float>(45.0f, result_literal, ErrorSpec{1e-7});
 }
 
 // Tests two Infeed operations with a total order. The order is enforced by
@@ -250,17 +250,17 @@
 
   // Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
-                               LiteralUtil::CreateR0<bool>(true).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+                                        LiteralUtil::CreateR0<bool>(true)})));
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
-                               LiteralUtil::CreateR0<bool>(true).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({3, 4}),
+                                        LiteralUtil::CreateR0<bool>(true)})));
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
-                               LiteralUtil::CreateR0<bool>(true).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({5, 6}),
+                                        LiteralUtil::CreateR0<bool>(true)})));
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
-                               LiteralUtil::CreateR0<bool>(false).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8}),
+                                        LiteralUtil::CreateR0<bool>(false)})));
 
   // Asynchronously launch the execution on the device.
   std::unique_ptr<GlobalData> result;
@@ -275,21 +275,21 @@
   // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
   sleep(1);
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
-                               LiteralUtil::CreateR0<bool>(true).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2, 3}),
+                                        LiteralUtil::CreateR0<bool>(true)})));
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
-                               LiteralUtil::CreateR0<bool>(false).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({7, 8, 9}),
+                                        LiteralUtil::CreateR0<bool>(false)})));
   ASSERT_IS_OK(client_->TransferToInfeed(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
-                               LiteralUtil::CreateR0<bool>(true).get()})));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({4, 5, 6}),
+                                        LiteralUtil::CreateR0<bool>(true)})));
 
   // Wait for the execution to be done, and transfer the result.
   delete computation_thread;  // Joins the thread.
   auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
 
   // Only the first 6 infeed data should be added.
-  LiteralTestUtil::ExpectR0Near<float>(66.0f, *result_literal, ErrorSpec{1e-7});
+  LiteralTestUtil::ExpectR0Near<float>(66.0f, result_literal, ErrorSpec{1e-7});
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index bb10519..7af51db 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -41,8 +41,7 @@
 TEST_F(CpuNoAliasTest, Concat) {
   HloComputation::Builder builder(TestName());
 
-  std::unique_ptr<Literal> literal =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
   HloInstruction* param_x = builder.AddInstruction(
       HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 09cb10d..b2ba261 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -134,9 +134,9 @@
     DotDimensionNumbers dot_dnums;
     dot_dnums.add_lhs_contracting_dimensions(1);
     dot_dnums.add_rhs_contracting_dimensions(0);
-    auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
-        dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));
-    dot_r2->set_precision_config(dot->precision_config());
+    auto dot_r2 = computation->AddInstruction(
+        HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2,
+                                  dot_dnums, dot->precision_config()));
 
     // Reshape Dot to R3 so we can concat along batch dimension.
     auto dot_r3 = computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index 1b3be19..852f34e 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -56,9 +56,9 @@
 }
 )";
 
-  std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
-  std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
-  RunTest(hlo_text, {lhs.get(), rhs.get()});
+  Literal lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+  Literal rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
+  RunTest(hlo_text, {&lhs, &rhs});
 }
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f66082..5fbd73a 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@
 namespace xla {
 namespace {
 
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
  protected:
   // Build and return a trivial computation taking and returning a scalar.
   std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@
   }
 
   {
-    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
     EXPECT_TRUE(result);
-    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
     const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
     EXPECT_EQ(1, c_node.caller_callsites().size());
   }
@@ -176,15 +176,15 @@
   }
 
   {
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
     const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
     EXPECT_EQ(2, cond_node.caller_callsites().size());
   }
 
   {
-    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+    TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
     EXPECT_TRUE(result);
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
     const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
     EXPECT_EQ(1, cond_node.caller_callsites().size());
   }
@@ -211,9 +211,9 @@
   module->AddEntryComputation(
       MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
 
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
   EXPECT_TRUE(result);
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   EXPECT_EQ(7, module->computation_count());
 
   const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@
   module->AddEntryComputation(builder.Build());
   EXPECT_EQ(2, module->computation_count());
 
-  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
   EXPECT_TRUE(result);
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   // The true and false computations must now be different.
   EXPECT_EQ(3, module->computation_count());
 
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 4ed91ef..bec02e1 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -125,7 +125,7 @@
                        device_memory.size());
           // Element is array-shaped: transfer array data to device buffer.
           const auto subliteral = LiteralSlice(literal, index);
-          std::unique_ptr<Literal> relayed_out_literal;
+          Literal relayed_out_literal;
           const void* source;
           if (LayoutUtil::Equal(device_subshape.layout(),
                                 subliteral.shape().layout())) {
@@ -138,7 +138,7 @@
             // Relayout data before transferring.
             relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                       /*shape_index=*/{});
-            source = relayed_out_literal->untyped_data();
+            source = relayed_out_literal.untyped_data();
             TF_RETURN_IF_ERROR(TransferBufferToDevice(
                 stream,
                 /*size=*/GetByteSizeRequirement(device_subshape), source,
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a68b7a1..64b9683 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,8 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
@@ -172,6 +174,7 @@
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:elemental_ir_emitter",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:name_uniquer",
         "//tensorflow/compiler/xla/service:while_loop_analysis",
         "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
@@ -369,6 +372,8 @@
     srcs = ["ir_emission_utils.cc"],
     hdrs = ["ir_emission_utils.h"],
     deps = [
+        ":backend_configs",
+        ":cudnn_convolution_runner",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:window_util",
@@ -394,6 +399,7 @@
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:device_memory_allocator",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
@@ -480,6 +486,7 @@
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
 )
@@ -811,9 +818,9 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/service:buffer_value",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_memory_scheduler",
         "//tensorflow/compiler/xla/service:hlo_ordering",
         "//tensorflow/compiler/xla/service:hlo_reachability",
-        "//tensorflow/compiler/xla/service:hlo_scheduling",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -830,6 +837,8 @@
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings:str_format",
@@ -898,6 +907,7 @@
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 05448d8..3a23ac1 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -20,6 +20,7 @@
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -30,62 +31,32 @@
 
 using se::dnn::AlgorithmDesc;
 
-ConvolutionThunk::ConvolutionThunk(
-    CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer,
-    const BufferAllocation::Slice& filter_buffer,
-    const BufferAllocation::Slice& output_buffer,
-    const BufferAllocation::Slice& tuple_result_buffer,
-    const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
-    const Shape& filter_shape, const Shape& output_shape, const Window& window,
-    const ConvolutionDimensionNumbers& dim_nums, int64 feature_group_count,
-    int64 algorithm, bool tensor_ops_enabled, const HloInstruction* hlo)
-    : Thunk(Kind::kConvolution, hlo),
-      convolution_kind_(convolution_kind),
-      input_buffer_(input_buffer),
-      filter_buffer_(filter_buffer),
-      output_buffer_(output_buffer),
-      tuple_result_buffer_(tuple_result_buffer),
-      scratch_buffer_(scratch_buffer),
-      input_shape_(input_shape),
-      filter_shape_(filter_shape),
-      output_shape_(output_shape),
-      window_(window),
-      dim_nums_(dim_nums),
-      feature_group_count_(feature_group_count),
-      algorithm_(algorithm),
-      tensor_ops_enabled_(tensor_ops_enabled) {}
-
 Status ConvolutionThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
     HloExecutionProfiler* profiler) {
-  se::DeviceMemoryBase input_data =
-      buffer_allocations.GetDeviceAddress(input_buffer_);
-  se::DeviceMemoryBase filter_data =
-      buffer_allocations.GetDeviceAddress(filter_buffer_);
-  se::DeviceMemoryBase output_data =
-      buffer_allocations.GetDeviceAddress(output_buffer_);
+  CudnnConvParams params;
+
+  params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
+  params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
+  params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
   se::DeviceMemoryBase scratch =
       buffer_allocations.GetDeviceAddress(scratch_buffer_);
 
-  se::dnn::AlgorithmConfig algorithm_config(
-      se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+  TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
 
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
-  TF_RETURN_IF_ERROR(RunCudnnConvolution(
-      convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
-      filter_data, output_data, scratch, window_, dim_nums_,
-      feature_group_count_, algorithm_config, stream));
+  TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
 
   // Figure out which of output/input/filter is the result produced by
   // this op, and write the result tuple.
   void* result_ptr = [&] {
-    switch (convolution_kind_) {
+    switch (params.kind) {
       case CudnnConvKind::kForward:
-        return output_data.opaque();
+        return params.output_buf.opaque();
       case CudnnConvKind::kBackwardInput:
-        return input_data.opaque();
+        return params.input_buf.opaque();
       case CudnnConvKind::kBackwardFilter:
-        return filter_data.opaque();
+        return params.filter_buf.opaque();
     }
   }();
   void* ptrs[] = {result_ptr, scratch.opaque()};
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 68d67c4..d7d1f91 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -24,6 +24,7 @@
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -32,7 +33,7 @@
 namespace xla {
 namespace gpu {
 
-// This class stores everything that StreamExecutor needs to launch a BNN
+// This class stores everything that StreamExecutor needs to launch a DNN
 // convolution. It is generated by IrEmitter.
 //
 // This is thread-compatible.
@@ -41,27 +42,24 @@
   // Constructs a thunk for launching a DNN convolution.  When run, it will
   // write a tuple (result, scratch_memory) into `tuple_result_buffer`.
   //
-  // `algorithm` is a cudnn algorithm number.  `algorithm == -1` indicates that
-  // we should use the default (i.e. baseline) cudnn algorithm.
-  //
   // Note that "output" here doesn't refer to the output from running this
   // thunk, but rather to the "output" of a hypothetical forward convolution
   // that corresponds to this input+filter+output triple.  That is, the result
   // generated by this thunk is "output" for forward convs, "input" for
   // backward-input convs, and "filter" for backward-filter convs.
-  //
-  // Semantics of null hlo_instruction argument are as in Thunk.
-  ConvolutionThunk(CudnnConvKind convolution_kind,
-                   const BufferAllocation::Slice& input_buffer,
-                   const BufferAllocation::Slice& filter_buffer,
-                   const BufferAllocation::Slice& output_buffer,
-                   const BufferAllocation::Slice& tuple_result_buffer,
-                   const BufferAllocation::Slice& scratch_buffer,
-                   const Shape& input_shape, const Shape& filter_shape,
-                   const Shape& output_shape, const Window& window,
-                   const ConvolutionDimensionNumbers& dim_nums,
-                   int64 feature_group_count, int64 algorithm,
-                   bool tensor_ops_enabled, const HloInstruction* hlo);
+  ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
+                   BufferAllocation::Slice input_slice,
+                   BufferAllocation::Slice filter_slice,
+                   BufferAllocation::Slice output_slice,
+                   BufferAllocation::Slice scratch_slice,
+                   BufferAllocation::Slice tuple_result_slice)
+      : Thunk(Kind::kConvolution, cudnn_call),
+        cudnn_call_(cudnn_call),
+        input_buffer_(std::move(input_slice)),
+        filter_buffer_(std::move(filter_slice)),
+        output_buffer_(std::move(output_slice)),
+        scratch_buffer_(std::move(scratch_slice)),
+        tuple_result_buffer_(std::move(tuple_result_slice)) {}
 
   ConvolutionThunk(const ConvolutionThunk&) = delete;
   ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -72,23 +70,12 @@
                          HloExecutionProfiler* profiler) override;
 
  private:
-  const CudnnConvKind convolution_kind_;
-
-  const BufferAllocation::Slice input_buffer_;
-  const BufferAllocation::Slice filter_buffer_;
-  const BufferAllocation::Slice output_buffer_;
-  const BufferAllocation::Slice tuple_result_buffer_;
-  const BufferAllocation::Slice scratch_buffer_;
-
-  const Shape input_shape_;
-  const Shape filter_shape_;
-  const Shape output_shape_;
-
-  const Window window_;
-  const ConvolutionDimensionNumbers dim_nums_;
-  int64 feature_group_count_;
-  int64 algorithm_;
-  bool tensor_ops_enabled_;
+  const HloCustomCallInstruction* cudnn_call_;
+  BufferAllocation::Slice input_buffer_;
+  BufferAllocation::Slice filter_buffer_;
+  BufferAllocation::Slice output_buffer_;
+  BufferAllocation::Slice scratch_buffer_;
+  BufferAllocation::Slice tuple_result_buffer_;
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 5c25551..c607aea 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -176,10 +177,14 @@
 // caching would speed up compilation a lot.
 StatusOr<std::tuple<int64, bool, int64>>
 CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    HloInstruction* instr) {
+    const HloCustomCallInstruction* instr) {
+  CudnnConvParams params;
+  TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
+
+  const Shape& input_shape = *params.input_shape;
+  const Shape& filter_shape = *params.filter_shape;
+  const Shape& output_shape = *params.output_shape;
+
   CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
   CHECK_EQ(input_shape.element_type(), output_shape.element_type());
   // TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -220,13 +225,13 @@
   // use a ScratchAllocator for this instead of calling allocator_ directly so
   // that our allocations don't leak.
   ScratchAllocator input_output_allocator(device_ordinal, allocator);
-  TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
+  TF_ASSIGN_OR_RETURN(params.input_buf,
                       input_output_allocator.AllocateBytes(
                           &stream, ShapeUtil::ByteSizeOf(input_shape)));
-  TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
+  TF_ASSIGN_OR_RETURN(params.filter_buf,
                       input_output_allocator.AllocateBytes(
                           &stream, ShapeUtil::ByteSizeOf(filter_shape)));
-  TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
+  TF_ASSIGN_OR_RETURN(params.output_buf,
                       input_output_allocator.AllocateBytes(
                           &stream, ShapeUtil::ByteSizeOf(output_shape)));
 
@@ -253,32 +258,32 @@
           static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
       stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
     };
-    initialize_f16(input_buf);
-    initialize_f16(filter_buf);
-    initialize_f16(output_buf);
+    initialize_f16(params.input_buf);
+    initialize_f16(params.filter_buf);
+    initialize_f16(params.output_buf);
   } else {
     // Although we don't have evidence this matters, zero out the buffers before
     // autotuning.  It's conceivable that using uninitialized memory as the
     // inputs might affect performance if e.g. the inputs contain denormals, and
     // this is easy enough.
-    stream.ThenMemZero(&input_buf, input_buf.size())
-        .ThenMemZero(&filter_buf, filter_buf.size())
-        .ThenMemZero(&output_buf, output_buf.size());
+    stream.ThenMemZero(&params.input_buf, params.input_buf.size())
+        .ThenMemZero(&params.filter_buf, params.filter_buf.size())
+        .ThenMemZero(&params.output_buf, params.output_buf.size());
   }
 
   DeviceMemoryBase* result_buf = [&] {
-    switch (kind) {
+    switch (params.kind) {
       case CudnnConvKind::kBackwardFilter:
-        return &filter_buf;
+        return &params.filter_buf;
       case CudnnConvKind::kBackwardInput:
-        return &input_buf;
+        return &params.input_buf;
       case CudnnConvKind::kForward:
-        return &output_buf;
+        return &params.output_buf;
     }
   }();
 
   const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
-      input_shape, output_shape, dnums, stream_exec_);
+      input_shape, output_shape, *params.dnums, stream_exec_);
   se::dnn::ProfileResult best_result;
   int64 best_result_bytes_used = 0;
 
@@ -288,18 +293,16 @@
   // this algorithm considered correct, though.
   optional<AlgorithmDesc> first_algorithm;
   for (const AlgorithmDesc& alg :
-       GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) {
+       GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
     ScratchAllocator scratch_allocator(device_ordinal, allocator);
     se::dnn::ProfileResult profile_result;
     VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
             << instr->ToString();
 
-    bool launch_ok =
-        RunCudnnConvolution(
-            kind, input_shape, filter_shape, output_shape, input_buf,
-            filter_buf, output_buf, &scratch_allocator, window, dnums,
-            feature_group_count, AlgorithmConfig(alg), &stream, &profile_result)
-            .ok();
+    params.algorithm = AlgorithmConfig(alg);
+    bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
+                                         &profile_result)
+                         .ok();
 
     if (launch_ok && profile_result.is_valid()) {
       const bool crash_on_checking_failure =
@@ -374,34 +377,8 @@
     HloInstruction* instr) {
   CHECK(IsCustomCallToDnnConvolution(*instr));
 
-  const auto& call_target = instr->custom_call_target();
-  const auto& lhs_shape = instr->operand(0)->shape();
-  const auto& rhs_shape = instr->operand(1)->shape();
-  const auto& conv_result_shape = instr->shape().tuple_shapes(0);
-  StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
-  if (call_target == kCudnnConvForwardCallTarget) {
-    alg_scratch_and_tc =
-        PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
-                          /*filter_shape=*/rhs_shape,
-                          /*output_shape=*/conv_result_shape, instr->window(),
-                          instr->convolution_dimension_numbers(),
-                          instr->feature_group_count(), instr);
-  } else if (call_target == kCudnnConvBackwardInputCallTarget) {
-    alg_scratch_and_tc = PickBestAlgorithm(
-        CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
-        /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
-        instr->convolution_dimension_numbers(), instr->feature_group_count(),
-        instr);
-  } else if (call_target == kCudnnConvBackwardFilterCallTarget) {
-    alg_scratch_and_tc = PickBestAlgorithm(
-        CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
-        /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
-        instr->window(), instr->convolution_dimension_numbers(),
-        instr->feature_group_count(), instr);
-  } else {
-    LOG(FATAL) << "Unknown custom call target for cudnn conv: "
-               << instr->ToString();
-  }
+  StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+      PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
 
   if (!alg_scratch_and_tc.ok()) {
     LOG(ERROR) << alg_scratch_and_tc.status();
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index 0cb0116..f79b113 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -20,6 +20,7 @@
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -49,10 +50,7 @@
   StatusOr<bool> RunOnComputation(HloComputation* computation);
   StatusOr<bool> RunOnInstruction(HloInstruction* instr);
   StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
-      CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-      const Shape& output_shape, const Window& window,
-      const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-      HloInstruction* instr);
+      const HloCustomCallInstruction* instr);
 
   se::StreamExecutor* stream_exec_;                   // never null
   DeviceMemoryAllocator* allocator_;                  // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 9bf721e..228379a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
 
+#include <cstdlib>
 #include <numeric>
 #include <vector>
 
@@ -59,8 +60,6 @@
     HloInstruction* conv) {
   const auto no_match_result =
       std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
-  // TODO(b/31709653): Figure out if we can use grouped convolutions also on
-  // backward filter.
   if (conv->feature_group_count() > 1) {
     return no_match_result;
   }
@@ -218,13 +217,16 @@
 
 // Try to match a backward input pattern that contains "conv".
 // Precondition: "conv" is a kConvolution.
-std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput(
-    HloInstruction* conv) {
+std::tuple<bool, Window, ConvolutionDimensionNumbers, HloInstruction*>
+MatchBackwardInput(HloInstruction* conv) {
   const auto no_match_result =
-      std::make_tuple(false, Window(), ConvolutionDimensionNumbers());
+      std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr);
 
-  // TODO(b/31709653): Figure out if we can use grouped convolutions also on
-  // backward input.
+  // TODO(b/31709653): Theoretically cuDNN supports grouped convolutions also
+  // for the backward input convolution, but at least for now with version 7.1.4
+  // it is slower. This needs to be re-evaluated for future cuDNN versions.
+  // Note that we already have the necessary code down below, the only thing to
+  // enable it is to remove the following early return.
   if (conv->feature_group_count() > 1) {
     return no_match_result;
   }
@@ -232,51 +234,38 @@
   // Match instruction pattern.
   CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
   HloInstruction* reverse_filter = conv->mutable_operand(1);
-
-  // Match the reverse of the filter.
   ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
-  const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
-  if (reverse_filter->opcode() == HloOpcode::kReverse) {
-    if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
-        !std::is_permutation(kernel_spatial_dims.begin(),
-                             kernel_spatial_dims.end(),
-                             reverse_filter->dimensions().begin())) {
-      VLOG(1)
-          << "Backward input convolution should reverse all kernel dimensions.";
-      return no_match_result;
-    }
-  } else if (reverse_filter->IsConstant()) {
-    // If the filter is a constant, we're willing to pattern-match to a
-    // backwards-input conv, on the theory that
-    //
-    //  a) reversing a constant is free, and
-    //  b) even if the user specified this filter as reverse(constant), we would
-    //     long ago have constant-folded away the reverse.
-    //
-    // If the constant has any other uses, reversing it isn't entirely free,
-    // since we'd now have two constants to keep in memory.  But hopefully it's
-    // free enough.
-    //
-    // TODO(jlebar): Should we do this even if the filter is not a constant?
-    // Reversing a non-constant filter is probably cheaper than padding the
-    // input!
 
-    // Nothing to do, just fall through.
-  } else {
-    // Possibly 1x1 filter.
-    for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
-      if (conv->window().dimensions(i).size() != 1) {
-        VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
-                << reverse_filter->ToString();
-        return no_match_result;
-      }
-    }
-    if (!window_util::HasBaseDilation(conv->window())) {
-      VLOG(1) << conv->ToString()
-              << " is a regular forward convolution. No need "
-                 "to fold it to a backward input convolution.";
-      return no_match_result;
-    }
+  // We pattern-match to a backwards input conv if:
+  //
+  //  - all spatial dims of the filter are reversed
+  //
+  // OR
+  //
+  //  - filter is 1x1 or a constant AND
+  //  - conv has base dilation (otherwise this is just a regular forward conv).
+  //
+  // The final criterion above is just for canonicalization; cudnn seems to run
+  // just as fast if we canonicalize 1x1/constant filters without base dilation
+  // to forward or backward convs.  We canonicalize to forward conv because (a)
+  // it's more natural (constant filters usually show up when doing inference,
+  // and having backwards convolutions in inference graphs would be weird), and
+  // (b) cudnn has special fusions for forward conv plus bias and activation,
+  // and we want to pattern-match to that after running this pass.
+  bool is_reversed_filter =
+      reverse_filter->opcode() == HloOpcode::kReverse &&
+      absl::c_is_permutation(dnums.kernel_spatial_dimensions(),
+                             reverse_filter->dimensions());
+  bool is_1x1_filter =
+      absl::c_all_of(conv->window().dimensions(),
+                     [](const WindowDimension& d) { return d.size() == 1; });
+  if (!is_reversed_filter &&
+      !(window_util::HasBaseDilation(conv->window()) &&
+        (reverse_filter->IsConstant() || is_1x1_filter))) {
+    VLOG(1) << "Can't match to backwards convolution. Either filter is not "
+               "kReverse, or it's not a base-dilated conv with a 1x1 or "
+               "constant filter.";
+    return no_match_result;
   }
 
   // Match padding and dilation of the forward convolution.
@@ -401,26 +390,64 @@
     }
   }
 
-  // OK, it's a match!  Canonicalize the conv's filter so that it's a reverse.
-  // This simplifies things for our caller, and algebraic-simplifier will later
-  // remove any unnecessary reverses.
-  if (reverse_filter->opcode() != HloOpcode::kReverse) {
-    // Create a double-reverse, which is a nop.
-    HloComputation* c = conv->parent();
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      AsInt64Slice(kernel_spatial_dims)));
-    reverse_filter = c->AddInstruction(
-        HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
-                                      AsInt64Slice(kernel_spatial_dims)));
-    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
-  }
-
+  // OK, it's a match! Switch the input feature dimension with the output
+  // feature dimension. This is the way cuDNN expects it to be.
   dnums.set_kernel_input_feature_dimension(
       conv->convolution_dimension_numbers().kernel_output_feature_dimension());
   dnums.set_kernel_output_feature_dimension(
       conv->convolution_dimension_numbers().kernel_input_feature_dimension());
-  return std::make_tuple(true, new_window, dnums);
+
+  // If we matched against a constant, we need to add a reverse op that can be
+  // subsumed by the cuDNN call. algebraic-simplifier will later remove any
+  // unnecessary reverses.
+  if (reverse_filter->opcode() != HloOpcode::kReverse &&
+      reverse_filter->IsConstant()) {
+    // Create a double-reverse, which is a nop.
+    HloComputation* c = conv->parent();
+    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+        reverse_filter->shape(), reverse_filter,
+        AsInt64Slice(dnums.kernel_spatial_dimensions())));
+    reverse_filter = c->AddInstruction(HloInstruction::CreateReverse(
+        reverse_filter->shape(), reverse_filter,
+        AsInt64Slice(dnums.kernel_spatial_dimensions())));
+    TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter));
+  }
+
+  // Calculate the 'rhs' that goes into the backward input convolution.
+  HloInstruction* rhs = reverse_filter;
+  // One reverse is subsumed by the cuDNN call.
+  if (rhs->opcode() == HloOpcode::kReverse) {
+    rhs = rhs->mutable_operand(0);
+  }
+  if (conv->feature_group_count() == 1) {
+    return std::make_tuple(true, new_window, dnums, rhs);
+  }
+
+  // Handle grouped convolutions. Because we swapped the input feature dimension
+  // with the output feature dimension, we need to also reshape the kernel so
+  // that the 'feature_group_count' parameter still makes sense. The
+  // 'feature_group_count' parameter essentially specifies how often the
+  // 'kernel_input_feature_dimension' is repeated. So when we swap these
+  // dimensions, we need to divide the new 'kernel_input_feature_dimension' by
+  // 'feature_group_count' and multiply the new
+  // 'kernel_output_feature_dimension' by 'feature_group_count'.
+  Shape new_shape = rhs->shape();
+  int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
+  int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+
+  // In the backward convolution case, the spatial dimensions become the
+  // feature dimensions, and we are guaranteed that the spatial dimensions are
+  // adjacent.
+  CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
+  int64 input_features = new_shape.dimensions(input_feature_dimension);
+  int64 output_features = new_shape.dimensions(output_feature_dimension);
+  new_shape.set_dimensions(input_feature_dimension,
+                           input_features / conv->feature_group_count());
+  new_shape.set_dimensions(output_feature_dimension,
+                           output_features * conv->feature_group_count());
+  HloComputation* c = conv->parent();
+  rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs));
+  return std::make_tuple(true, new_window, dnums, rhs);
 }
 
 // Tries to rewrite a single convolution into a call to cudnn.
@@ -431,6 +458,7 @@
     bool match;
     Window window;
     ConvolutionDimensionNumbers dnums;
+    HloInstruction* rhs;
 
     std::tie(match, window, dnums) = MatchBackwardFilter(conv);
     if (match) {
@@ -439,13 +467,8 @@
           window, dnums, conv->feature_group_count());
     }
 
-    std::tie(match, window, dnums) = MatchBackwardInput(conv);
+    std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
     if (match) {
-      // Backward input conv subsumes the conv plus the reverse in operand 1.
-      HloInstruction* reverse = conv->mutable_operand(1);
-      CHECK_EQ(reverse->opcode(), HloOpcode::kReverse);
-      HloInstruction* rhs = reverse->mutable_operand(0);
-
       return CreateCudnnConvBackwardInput(conv->shape(),
                                           conv->mutable_operand(0), rhs, window,
                                           dnums, conv->feature_group_count());
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
index 46c23db..d237f89 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc
@@ -107,12 +107,12 @@
   conv_window.mutable_dimensions(1)->set_size(2);
   conv_window.mutable_dimensions(1)->set_window_dilation(2);
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(activations->shape(),
-                                         gradients->shape(), conv_window,
-                                         tf_default_dnums_for_backward_filter_)
+      ShapeInference::InferConvolveShape(
+          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_filter_)
           .ConsumeValueOrDie(),
-      activations, gradients, conv_window,
-      tf_default_dnums_for_backward_filter_));
+      activations, gradients, /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -135,12 +135,12 @@
   Window conv_window = default_conv_window_;
   conv_window.mutable_dimensions(1)->set_size(3);
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(activations->shape(),
-                                         gradients->shape(), conv_window,
-                                         tf_default_dnums_for_backward_filter_)
+      ShapeInference::InferConvolveShape(
+          activations->shape(), gradients->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_filter_)
           .ConsumeValueOrDie(),
-      activations, gradients, conv_window,
-      tf_default_dnums_for_backward_filter_));
+      activations, gradients, /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -170,7 +170,8 @@
   }
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
-      conv_window, tf_default_dnums_for_backward_filter_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -200,7 +201,8 @@
   }
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
-      conv_window, tf_default_dnums_for_backward_filter_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -228,7 +230,8 @@
   }
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
-      conv_window, tf_default_dnums_for_backward_filter_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -272,13 +275,14 @@
 
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
-      /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+      /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window,
+      conv_dnums, DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(),
-      ShapeInference::InferConvolveShape(
-          output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
-          .ValueOrDie()));
+      conv->shape(), ShapeInference::InferConvolveShape(
+                         output->shape(), reverse_kernel->shape(),
+                         /*feature_group_count=*/1, conv_window, conv_dnums)
+                         .ValueOrDie()));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -319,11 +323,11 @@
 
   builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
-                                         conv_window,
+                                         /*feature_group_count=*/1, conv_window,
                                          tf_default_dnums_for_backward_input_)
           .ConsumeValueOrDie(),
-      /*lhs=*/output, /*rhs=*/kernel, conv_window,
-      tf_default_dnums_for_backward_input_));
+      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -350,12 +354,13 @@
           1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
 
   builder.AddInstruction(HloInstruction::CreateConvolve(
-      ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
-                                         default_conv_window_,
-                                         tf_default_dnums_for_backward_input_)
+      ShapeInference::InferConvolveShape(
+          output->shape(), kernel->shape(), /*feature_group_count=*/1,
+          default_conv_window_, tf_default_dnums_for_backward_input_)
           .ConsumeValueOrDie(),
-      /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
-      tf_default_dnums_for_backward_input_));
+      /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
+      default_conv_window_, tf_default_dnums_for_backward_input_,
+      DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -402,13 +407,15 @@
   }
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
-      conv_window, tf_default_dnums_for_backward_input_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(), conv_window,
-                         tf_default_dnums_for_backward_input_)
-                         .ValueOrDie()));
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_input_)
+          .ValueOrDie()));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -449,13 +456,15 @@
   }
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
-      conv_window, tf_default_dnums_for_backward_input_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(), conv_window,
-                         tf_default_dnums_for_backward_input_)
-                         .ValueOrDie()));
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_input_)
+          .ValueOrDie()));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -502,13 +511,15 @@
   forward_conv_col_dim->set_base_dilation(2);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
-      conv_window, tf_default_dnums_for_backward_input_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(), conv_window,
-                         tf_default_dnums_for_backward_input_)
-                         .ValueOrDie()));
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_input_)
+          .ValueOrDie()));
 
   auto module = CreateNewModule();
   const HloComputation* entry_computation =
@@ -554,13 +565,15 @@
   forward_conv_col_dim->set_padding_high(2);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
       ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
-      conv_window, tf_default_dnums_for_backward_input_));
+      /*feature_group_count=*/1, conv_window,
+      tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2)));
   // Verify the convolution's shape is consistent with ShapeInference.
   CHECK(ShapeUtil::Compatible(
-      conv->shape(), ShapeInference::InferConvolveShape(
-                         output->shape(), reverse_kernel->shape(), conv_window,
-                         tf_default_dnums_for_backward_input_)
-                         .ValueOrDie()));
+      conv->shape(),
+      ShapeInference::InferConvolveShape(
+          output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1,
+          conv_window, tf_default_dnums_for_backward_input_)
+          .ValueOrDie()));
 
   auto module = CreateNewModule();
   HloComputation* entry_computation =
@@ -577,7 +590,7 @@
   Array4D<float> constant_arr(4, 4, 2, 2);
   constant_arr.FillIota(0);
   string constant_str =
-      LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString();
+      LiteralUtil::CreateR4FromArray4D(constant_arr).ToString();
   ParseAndVerifyModule(absl::StrFormat(R"(
     HloModule test
 
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 05125e9..2a86ac2 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -72,14 +72,22 @@
 };
 
 template <typename T>
-Status RunCudnnConvolution(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, DeviceMemory<T> input_buf,
-    DeviceMemory<T> filter_buf, DeviceMemory<T> output_buf,
-    se::ScratchAllocator* scratch_allocator, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    AlgorithmConfig algorithm, Stream* stream,
-    ProfileResult* profile_result /*= nullptr*/) {
+Status RunCudnnConvolutionImpl(CudnnConvParams params,
+                               se::ScratchAllocator* scratch_allocator,
+                               se::Stream* stream,
+                               se::dnn::ProfileResult* profile_result) {
+  CudnnConvKind kind = params.kind;
+  const Shape& input_shape = *params.input_shape;
+  const Shape& filter_shape = *params.filter_shape;
+  const Shape& output_shape = *params.output_shape;
+  DeviceMemory<T> input_buf(params.input_buf);
+  DeviceMemory<T> filter_buf(params.filter_buf);
+  DeviceMemory<T> output_buf(params.output_buf);
+  const Window& window = *params.window;
+  const ConvolutionDimensionNumbers& dnums = *params.dnums;
+  int64 feature_group_count = params.feature_group_count;
+  AlgorithmConfig algorithm = params.algorithm;
+
   VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
   VLOG(3) << "tensor_ops_enabled: "
           << algorithm.algorithm().tensor_ops_enabled();
@@ -219,54 +227,31 @@
   }
 }
 
-Status RunCudnnConvolution(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, se::DeviceMemoryBase input_buf,
-    se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
-    se::DeviceMemoryBase scratch_buf, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
-    se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvolution(CudnnConvParams params,
+                           se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+                           se::dnn::ProfileResult* profile_result) {
   ScratchBufAllocator scratch_allocator(scratch_buf);
-  return RunCudnnConvolution(
-      kind, input_shape, filter_shape, output_shape, input_buf, filter_buf,
-      output_buf, &scratch_allocator, window, dnums, feature_group_count,
-      algorithm, stream, profile_result);
+  return RunCudnnConvolution(params, &scratch_allocator, stream,
+                             profile_result);
 }
 
-Status RunCudnnConvolution(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, se::DeviceMemoryBase input_buf,
-    se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
-    se::ScratchAllocator* scratch_allocator, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
-    se::dnn::ProfileResult* profile_result) {
-  PrimitiveType output_primitive_type = output_shape.element_type();
+Status RunCudnnConvolution(CudnnConvParams params,
+                           se::ScratchAllocator* scratch_allocator,
+                           se::Stream* stream,
+                           se::dnn::ProfileResult* profile_result) {
+  PrimitiveType output_primitive_type = params.output_shape->element_type();
   switch (output_primitive_type) {
     case F16:
-      return RunCudnnConvolution(
-          kind, input_shape, filter_shape, output_shape,
-          se::DeviceMemory<Eigen::half>(input_buf),
-          se::DeviceMemory<Eigen::half>(filter_buf),
-          se::DeviceMemory<Eigen::half>(output_buf), scratch_allocator, window,
-          dnums, feature_group_count, algorithm, stream, profile_result);
+      return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
+                                                  stream, profile_result);
     case F32:
-      return RunCudnnConvolution(
-          kind, input_shape, filter_shape, output_shape,
-          se::DeviceMemory<float>(input_buf),
-          se::DeviceMemory<float>(filter_buf),
-          se::DeviceMemory<float>(output_buf), scratch_allocator, window, dnums,
-          feature_group_count, algorithm, stream, profile_result);
+      return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
+                                            profile_result);
     case F64:
-      return RunCudnnConvolution(
-          kind, input_shape, filter_shape, output_shape,
-          se::DeviceMemory<double>(input_buf),
-          se::DeviceMemory<double>(filter_buf),
-          se::DeviceMemory<double>(output_buf), scratch_allocator, window,
-          dnums, feature_group_count, algorithm, stream, profile_result);
+      return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
+                                             profile_result);
     default:
-      LOG(FATAL) << ShapeUtil::HumanString(output_shape);
+      LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
   }
 }
 
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index a1b4fc7..381aa37 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -47,6 +47,20 @@
   kBackwardFilter,  // input  + output => filter
 };
 
+struct CudnnConvParams {
+  CudnnConvKind kind;
+  const Shape* input_shape;
+  const Shape* filter_shape;
+  const Shape* output_shape;
+  se::DeviceMemoryBase input_buf;
+  se::DeviceMemoryBase filter_buf;
+  se::DeviceMemoryBase output_buf;
+  const Window* window;
+  const ConvolutionDimensionNumbers* dnums;
+  int64 feature_group_count;
+  se::dnn::AlgorithmConfig algorithm;
+};
+
 // Converts a CudnnConvKind value to a string.
 string CudnnConvKindToString(CudnnConvKind kind);
 
@@ -55,10 +69,9 @@
 // Note that depending on the value of CudnnConvKind, the result of this call
 // may be written into input_buf, filter_buf, or output_buf!
 //
-// At the moment we only support cudnn convolutions over float and half, and
-// convolution with half data type is implemented with cudnn PSEUDO_HALF
-// configuration, that is, the input values are half and the internal
-// computation type is float.
+// At the moment convolution with half data type is implemented with cudnn
+// PSEUDO_HALF configuration, that is, the input values are half and the
+// internal computation type is float.
 //
 // We provide one overload which takes a scratch buffer, and another which takes
 // an allocator which is responsible for allocating the scratch space.  In
@@ -70,23 +83,14 @@
 // allocator and take note of how much memory is used.  The next time you call
 // the same conv, you can provide an explicitly preallocated scratch buffer of
 // that size, if you like.
-Status RunCudnnConvolution(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, se::DeviceMemoryBase input_buf,
-    se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
-    se::DeviceMemoryBase scratch_buf, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
-    se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+                           se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+                           se::dnn::ProfileResult* profile_result = nullptr);
 
-Status RunCudnnConvolution(
-    CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
-    const Shape& output_shape, se::DeviceMemoryBase input_buf,
-    se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf,
-    se::ScratchAllocator* scratch_allocator, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count,
-    se::dnn::AlgorithmConfig algorithm, se::Stream* stream,
-    se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConvolution(CudnnConvParams params,
+                           se::ScratchAllocator* scratch_allocator,
+                           se::Stream* stream,
+                           se::dnn::ProfileResult* profile_result = nullptr);
 
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
index 743035a..02a0d02 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc
@@ -21,8 +21,9 @@
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
@@ -198,11 +199,12 @@
     // All kernels are launched on a single stream, so there's no loss of
     // concurrency by optimizing for minimal memory usage.
     TF_ASSIGN_OR_RETURN(
-        schedule->thunk_launch_order_,
-        ScheduleOneComputation(
+        HloInstructionSequence sequence,
+        ScheduleComputation(
             *entry_computation, [pointer_size](const BufferValue& buffer) {
               return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
             }));
+    schedule->thunk_launch_order_ = sequence.instructions();
   } else {
     // BFS tends to increase concurrency, but also increases memory usage.
     BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
index 30a0e7c..07a7fc6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h
@@ -33,7 +33,9 @@
 // launches, because thunks may be scheduled onto concurrent streams. This
 // schedule is used by BufferAssigner to determine buffer liveness (i.e. to
 // minimize allocations), and also by ThunkSchedule to determine the thunk
-// launch order.
+// launch order. This class differs from xla::HloSchedule in that HloSchedule
+// represents a total order of all instructions in the module for backends which
+// execute HLO instructions strictly sequentially.
 class GpuHloSchedule {
  public:
   // Constructs an GpuHloSchedule for the given module, based on the given
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0922e44..b857fa7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -24,13 +24,14 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
 
-class GpuHloScheduleTest : public HloTestBase {
+class GpuHloScheduleTest : public HloVerifiedTestBase {
  protected:
   using HloVec = std::vector<const HloInstruction*>;
 
@@ -73,10 +74,10 @@
       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
-  HloInstruction* dot1 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
-  HloInstruction* dot2 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+  HloInstruction* dot1 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+  HloInstruction* dot2 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build(dot2));
@@ -201,12 +202,12 @@
       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
-  HloInstruction* dot1 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
-  HloInstruction* dot2 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
-  HloInstruction* add = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
+  HloInstruction* dot1 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+  HloInstruction* dot2 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
+  HloInstruction* add =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build(add));
@@ -269,23 +270,23 @@
         i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
   }
   HloInstruction* d00 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
-  HloInstruction* d10 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
-  HloInstruction* d11 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
-  HloInstruction* d20 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
-  HloInstruction* d21 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
-  HloInstruction* d22 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
-  HloInstruction* d30 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
-  HloInstruction* d31 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
-  HloInstruction* d40 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+      CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+  HloInstruction* d10 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+  HloInstruction* d11 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+  HloInstruction* d20 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+  HloInstruction* d21 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+  HloInstruction* d22 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+  HloInstruction* d30 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+  HloInstruction* d31 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+  HloInstruction* d40 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
index 0a4089d..27a4d0b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc
@@ -16,7 +16,7 @@
 #include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -25,7 +25,7 @@
 
 using ::testing::HasSubstr;
 
-class GpuHloSupportCheckerTest : public HloTestBase {
+class GpuHloSupportCheckerTest : public HloVerifiedTestBase {
  protected:
   GpuHloSupportChecker& checker() { return checker_; }
 
@@ -45,7 +45,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK(checker().Run(module.get()).status());
+  TF_ASSERT_OK(checker().Run(module).status());
 }
 
 TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) {
@@ -60,7 +60,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Status status = checker().Run(module.get()).status();
+  Status status = checker().Run(module).status();
   ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
   EXPECT_THAT(status.error_message(),
               HasSubstr("GPU backend does not support"));
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index bca775c..96bfe0c 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/util.h"
 
 namespace op = xla::testing::opcode_matchers;
@@ -111,8 +112,8 @@
   HloComputation::Builder builder(TestName());
   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
-  auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
-      ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+  auto dot1 = builder.AddInstruction(
+      CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
   auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
       ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
 
@@ -128,8 +129,8 @@
   HloComputation::Builder builder(TestName());
   auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
-  auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
-      ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+  auto dot1 = builder.AddInstruction(
+      CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
   auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
       ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 20d523a..22f43bc0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,6 +20,7 @@
 
 #include "llvm/IR/Module.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -287,5 +288,42 @@
       value->getType());
 }
 
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+                               CudnnConvParams* params) {
+  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+                      custom_call->backend_config<CudnnConvBackendConfig>());
+  const auto& target = custom_call->custom_call_target();
+  const auto& lhs_shape = custom_call->operand(0)->shape();
+  const auto& rhs_shape = custom_call->operand(1)->shape();
+  const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
+
+  params->window = &custom_call->window();
+  params->dnums = &custom_call->convolution_dimension_numbers();
+  params->feature_group_count = custom_call->feature_group_count();
+  params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+      backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+
+  if (target == kCudnnConvForwardCallTarget) {
+    params->kind = CudnnConvKind::kForward;
+    params->input_shape = &lhs_shape;
+    params->filter_shape = &rhs_shape;
+    params->output_shape = &conv_result_shape;
+  } else if (target == kCudnnConvBackwardInputCallTarget) {
+    params->kind = CudnnConvKind::kBackwardInput;
+    params->input_shape = &conv_result_shape;
+    params->filter_shape = &rhs_shape;
+    params->output_shape = &lhs_shape;
+  } else if (target == kCudnnConvBackwardFilterCallTarget) {
+    params->kind = CudnnConvKind::kBackwardFilter;
+    params->input_shape = &lhs_shape;
+    params->filter_shape = &conv_result_shape;
+    params->output_shape = &rhs_shape;
+  } else {
+    LOG(FATAL) << "Unexpected custom call target: "
+               << custom_call->custom_call_target();
+  }
+  return Status::OK();
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 59c65fc..09c455c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,9 @@
 
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 
 // TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
 // don't belong in "ir_emission_utils".
@@ -148,6 +150,11 @@
 llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                      llvm::IRBuilder<>* builder);
 
+// Populates params using custom_call, which must be a custom-call to a cudnn
+// convolution.  Does not modify any buffers in the params.
+Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
+                               CudnnConvParams* params);
+
 }  // namespace gpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index ffca5d6..b7c37bc 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -764,5 +764,20 @@
   return Load(return_buffer);
 }
 
+std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
+    const HloInstruction& hlo) {
+  std::vector<llvm_ir::IrArray> output_arrays;
+  if (ShapeUtil::IsTuple(hlo.shape())) {
+    int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+    output_arrays.reserve(num_outputs);
+    for (int64 i = 0; i < num_outputs; ++i) {
+      output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+    }
+  } else {
+    output_arrays.push_back(GetIrArray(hlo, hlo));
+  }
+  return output_arrays;
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 579268f..8805201 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -124,6 +124,12 @@
   llvm::Value* GetBasePointer(const HloInstruction& inst) const {
     return bindings_.GetBasePointer(inst);
   }
+
+  // Generates the IrArray for each output of an hlo instruction and returns
+  // a vector containing such IrArrays.
+  std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
+      const HloInstruction& hlo);
+
   // A convenient helper for calling BufferAssignment::GetUniqueSlice.
   BufferAllocation::Slice GetAllocationSlice(
       const HloInstruction& hlo, const ShapeIndex& index = {}) const {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 5c827e5..66c65f6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -119,21 +119,11 @@
   // For MOF we give the loop emitter an array for every output it should
   // generate.
   if (hlo.IsMultiOutputFusion()) {
-    const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape());
-    std::vector<llvm_ir::IrArray> target_arrays;
-    target_arrays.reserve(num_elems);
-    for (int64 i = 0; i != num_elems; ++i) {
-      target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
-    }
+    std::vector<llvm_ir::IrArray> target_arrays =
+        ConstructIrArrayForOutputs(hlo);
     TF_RETURN_IF_ERROR(
         llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
-
-    std::vector<llvm::Value*> tuple_operand_ptrs;
-    tuple_operand_ptrs.reserve(num_elems);
-    for (const llvm_ir::IrArray& array : target_arrays) {
-      tuple_operand_ptrs.push_back(array.GetBasePointer());
-    }
-    llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+    llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
     return Status::OK();
   }
   return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 389a98f..b669881 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -61,6 +61,7 @@
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -464,67 +465,35 @@
 
   if (IsCustomCallToDnnConvolution(*custom_call)) {
     const auto& assn = ir_emitter_context_->buffer_assignment();
-    const auto& lhs_shape = custom_call->operand(0)->shape();
-    const auto& rhs_shape = custom_call->operand(1)->shape();
-    const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
     auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
     auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
     auto tuple_result_slice = GetAllocationSlice(*custom_call);
     auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
     auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
 
-    TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
-                        custom_call->backend_config<CudnnConvBackendConfig>());
     const auto& target = custom_call->custom_call_target();
-    std::unique_ptr<ConvolutionThunk> thunk;
+    BufferAllocation::Slice input_slice, filter_slice, output_slice;
+
     if (target == kCudnnConvForwardCallTarget) {
-      thunk = absl::make_unique<ConvolutionThunk>(
-          CudnnConvKind::kForward,
-          /*input_buffer=*/lhs_slice,
-          /*filter_buffer=*/rhs_slice,
-          /*output_buffer=*/conv_result_slice,
-          /*tuple_result_buffer=*/tuple_result_slice,
-          /*scratch_buffer=*/scratch_slice,
-          /*input_shape=*/lhs_shape,
-          /*filter_shape=*/rhs_shape,
-          /*output_shape=*/conv_result_shape,  //
-          custom_call->window(), custom_call->convolution_dimension_numbers(),
-          custom_call->feature_group_count(), backend_config.algorithm(),
-          backend_config.tensor_ops_enabled(), custom_call);
+      input_slice = lhs_slice;
+      filter_slice = rhs_slice;
+      output_slice = conv_result_slice;
     } else if (target == kCudnnConvBackwardInputCallTarget) {
-      thunk = absl::make_unique<ConvolutionThunk>(
-          CudnnConvKind::kBackwardInput,
-          /*input_buffer=*/conv_result_slice,
-          /*filter_buffer=*/rhs_slice,
-          /*output_buffer=*/lhs_slice,
-          /*tuple_result_buffer=*/tuple_result_slice,
-          /*scratch_buffer=*/scratch_slice,
-          /*input_shape=*/conv_result_shape,
-          /*filter_shape=*/rhs_shape,
-          /*output_shape=*/lhs_shape,  //
-          custom_call->window(), custom_call->convolution_dimension_numbers(),
-          custom_call->feature_group_count(), backend_config.algorithm(),
-          backend_config.tensor_ops_enabled(), custom_call);
+      input_slice = conv_result_slice;
+      filter_slice = rhs_slice;
+      output_slice = lhs_slice;
     } else if (target == kCudnnConvBackwardFilterCallTarget) {
-      thunk = absl::make_unique<ConvolutionThunk>(
-          CudnnConvKind::kBackwardFilter,
-          /*input_buffer=*/lhs_slice,
-          /*filter_buffer=*/conv_result_slice,
-          /*output_buffer=*/rhs_slice,
-          /*tuple_result_buffer=*/tuple_result_slice,
-          /*scratch_buffer=*/scratch_slice,
-          /*input_shape=*/lhs_shape,
-          /*filter_shape=*/conv_result_shape,
-          /*output_shape=*/rhs_shape,  //
-          custom_call->window(), custom_call->convolution_dimension_numbers(),
-          custom_call->feature_group_count(), backend_config.algorithm(),
-          backend_config.tensor_ops_enabled(), custom_call);
+      input_slice = lhs_slice;
+      filter_slice = conv_result_slice;
+      output_slice = rhs_slice;
     } else {
       LOG(FATAL) << "Unexpected custom call target: "
                  << custom_call->custom_call_target();
     }
 
-    thunk_sequence_->emplace_back(std::move(thunk));
+    thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
+        Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
+        output_slice, scratch_slice, tuple_result_slice));
     return Status::OK();
   }
 
@@ -2521,15 +2490,15 @@
 }
 
 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
-    const HloInstruction* hlo, const ShapeIndex& index) {
+    HloInstruction* hlo, const ShapeIndex& index) {
   bool fused = HloOpcode::kFusion == hlo->opcode();
-  const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
-  const HloInstruction* init_value_operand = [&] {
+  HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+  HloInstruction* init_value_operand = [&] {
     switch (inst->opcode()) {
       case HloOpcode::kSelectAndScatter:
-        return inst->operand(2);
+        return inst->mutable_operand(2);
       case HloOpcode::kReduce:
-        return inst->operand(1);
+        return inst->mutable_operand(1);
       case HloOpcode::kTuple:
         CHECK(hlo->IsMultiOutputFusion())
             << ": " << hlo->ToString() << " is not a multi-output fusion.";
@@ -2537,7 +2506,7 @@
             << ": Found '" << inst->operand(index.back())->opcode() << "' in "
             << inst->ToString() << " but expected 'reduce'.";
         // For multi-output fusion look through the tuple.
-        return inst->operand(index.back())->operand(1);
+        return inst->mutable_operand(index.back())->mutable_operand(1);
       default:
         LOG(FATAL) << "Opcode " << inst->opcode()
                    << " should not need an initializer.";
@@ -2609,28 +2578,35 @@
                                 ir_emitter_context_->device_description());
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());
-  // If the init_value was fused into this reduce we have to generate it first.
-  if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
-    CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
 
-    const Literal& literal = init_value_operand->literal();
-    llvm::Constant* initializer =
-        llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+  if (fused) {
+    // If init_value was fused into this reduce we have to generate it first.
+    std::vector<IrArray> parameter_arrays;
+    for (HloInstruction* operand : hlo->operands()) {
+      parameter_arrays.push_back(GetIrArray(*operand, *hlo));
+    }
+    GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+                                            ir_emitter_context_->llvm_module(),
+                                            &b_, GetNestedComputer());
 
-    llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
-        *module_, initializer->getType(),
-        /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
-        /*Name=*/"");
-    global_for_const->setAlignment(kConstantBufferAlignBytes);
-    bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
+    FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+    TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
+    TF_RETURN_IF_ERROR(
+        ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
+                            GetIrArray(*hlo, *hlo, index), launch_dimensions,
+                            &b_)
+            .EmitLoop(IrName(hlo)));
+  } else {
+    // In the unfused case the element is already there, just read from it.
+    TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+                           [=](const IrArray::Index& index) {
+                             return GetIrArray(*init_value, *hlo)
+                                 .EmitReadArrayElement(index, &b_);
+                           },
+                           GetIrArray(*hlo, *hlo, index), launch_dimensions,
+                           &b_)
+                           .EmitLoop(IrName(hlo)));
   }
-  TF_RETURN_IF_ERROR(ParallelLoopEmitter(
-                         [=](const IrArray::Index& index) {
-                           return GetIrArray(*init_value, *hlo)
-                               .EmitReadArrayElement(index, &b_);
-                         },
-                         GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
-                         .EmitLoop(IrName(hlo)));
 
   // Clean up state left behind by emitting the loop above.  (This is normally
   // done in IrEmitterUnnested::Postprocess().)
@@ -2819,10 +2795,7 @@
   }
 
   // For multioutput fusion, we need to emit each operand and the root.
-  std::vector<IrArray> output_arrays;
-  for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
-    output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
-  }
+  std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
   TF_RETURN_IF_ERROR(
       ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
                           &b_, unroll_factor)
@@ -2830,12 +2803,9 @@
                     GetIndexTypeForKernel(
                         &hlo, launch_dimensions.launch_bound(), &b_)));
 
-  std::vector<llvm::Value*> tuple_operand_ptrs;
-  for (int64 i = 0; i < output_arrays.size(); ++i) {
-    tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
-  }
   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
-  llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+  llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
+
   return Status::OK();
 }
 
@@ -2847,29 +2817,14 @@
                                       static_cast<KernelThunk*>(LastThunk()));
 }
 
-int IrEmitterUnnested::ConstructIrArrayForOutputs(
-    const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
-  int64 num_outputs = 1;
-  if (hlo.IsMultiOutputFusion()) {
-    num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
-    output_arrays->reserve(num_outputs);
-    for (int64 i = 0; i < num_outputs; ++i) {
-      output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
-    }
-  } else {
-    output_arrays->push_back(GetIrArray(hlo, hlo));
-  }
-  return num_outputs;
-}
-
-int IrEmitterUnnested::ConstructIrArrayForInputs(
-    const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
-  int64 num_params = hlo.operands().size();
-  param_arrays->reserve(num_params);
+std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
+    const HloInstruction& hlo) {
+  std::vector<IrArray> param_arrays;
+  param_arrays.reserve(hlo.operands().size());
   for (const HloInstruction* param : hlo.operands()) {
-    param_arrays->push_back(GetIrArray(*param, hlo));
+    param_arrays.push_back(GetIrArray(*param, hlo));
   }
-  return num_params;
+  return param_arrays;
 }
 
 int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
@@ -3050,10 +3005,10 @@
   constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
 
   // Construct IrArrays for the inputs and outputs.
-  std::vector<IrArray> output_arrays;
-  int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
-  std::vector<IrArray> param_arrays;
-  int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+  std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
+  int64 num_outputs = output_arrays.size();
+  std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo);
+  int64 num_params = param_arrays.size();
 
   // Allocate shared memory buffers to store the tiled inputs.
   std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
@@ -3251,12 +3206,7 @@
 
   // For multioutput fusion, emit a tuple with all the individual outputs.
   if (hlo->IsMultiOutputFusion()) {
-    std::vector<llvm::Value*> tuple_operand_ptrs;
-    for (int64 i = 0; i < output_arrays.size(); ++i) {
-      tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
-    }
-    llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
-                       module_);
+    llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_);
   }
 
   return launch_dimensions;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 0844623..bd5db72 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -193,14 +193,12 @@
   LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
                                   absl::Span<const int64> reduced_output_dims,
                                   absl::Span<const int64> tiled_param_ids);
-  // Generates the IrArray for each output of hlo and returns the number of
-  // outputs.
-  int ConstructIrArrayForOutputs(const HloInstruction& hlo,
-                                 std::vector<llvm_ir::IrArray>* output_arrays);
-  // Generates the IrArray for each input of hlo and returns the number of
-  // inputs.
-  int ConstructIrArrayForInputs(const HloInstruction& hlo,
-                                std::vector<llvm_ir::IrArray>* param_arrays);
+
+  // Generates the IrArray for each input of an hlo and returns a vector that
+  // contains such IrArrays.
+  std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
+      const HloInstruction& hlo);
+
   // For each output of the `hlo` instruction, constructs the reduced shape for
   // the output with the given `reduced_output_dims` and cast the original
   // output IrArray element in `output_arrays` to the reduced shape. Returns
@@ -244,7 +242,7 @@
   // Returns a thunk that, given a reduce or select-and-scatter op, initializes
   // its memory to the appropriate initial value.
   StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
-      const HloInstruction* hlo, const ShapeIndex& index = {});
+      HloInstruction* hlo, const ShapeIndex& index = {});
 
   // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
   std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index c822c94..8a6e532 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -259,7 +259,7 @@
 TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) {
   // Fusing a reduce into a loop fusion would require changing the fusion kind.
   // That's not supported yet.
-  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
     fused_computation_1 {
       p0.1 = f32[6400]{0} parameter(0)
       ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@@ -277,7 +277,7 @@
 }
 
 TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) {
-  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
     fused_computation_1 {
       p0.1 = f32[6400]{0} parameter(0)
       ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
@@ -301,7 +301,7 @@
 }
 
 TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) {
-  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
     fused_computation_1 {
       p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
       ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
@@ -324,7 +324,7 @@
 }
 
 TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) {
-  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
     fused_computation_1 {
       p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
       mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
@@ -358,7 +358,7 @@
 
 TEST_F(MultiOutputFusionTest,
        MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) {
-  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+  auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
     fused_computation_1 {
       p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0)
       mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1)
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index f6325b3..dfdcf18 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -208,10 +208,6 @@
     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);
     pipeline.AddPass<CudnnConvolutionRewriter>();
-    // CudnnConvolutionRewriter may add instructions of the form
-    // reverse(constant), which it expects will be simplified by constant
-    // folding.
-    pipeline.AddPass<HloConstantFolding>();
     pipeline.AddPass<PadInsertion>();
     if (IsVoltaOrLater(*stream_exec)) {
       pipeline.AddPass<PadForTensorCores>();
@@ -219,6 +215,9 @@
       // pairs that TupleSimplifier fixes.
       pipeline.AddPass<TupleSimplifier>();
     }
+    // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add
+    // instructions which can be simplified by constant folding.
+    pipeline.AddPass<HloConstantFolding>();
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index fa84d77..b0061fa 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -23,7 +23,6 @@
 namespace xla {
 namespace gpu {
 
-
 // We want the input/output feature counts of an f16 conv to be factors of 8,
 // because without this cudnn can't use tensor cores on the conv.
 static constexpr int64 kDesiredNumFeaturesFactor = 8;
@@ -63,8 +62,8 @@
   HloComputation* comp = instr->parent();
 
   const Shape& shape = instr->shape();
-  auto* zero = comp->AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::Zero(shape.element_type()).CloneToUnique()));
+  auto* zero = comp->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(shape.element_type())));
 
   PaddingConfig pad_config = MakeNoPaddingConfig(ShapeUtil::Rank(shape));
 
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 9d85d74..2a6415d 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -68,9 +68,8 @@
           conv_window.dimensions(i).base_dilation() - 1);
     }
     PrimitiveType element_type = input->shape().element_type();
-    HloInstruction* padding =
-        computation->AddInstruction(HloInstruction::CreateConstant(
-            absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+    HloInstruction* padding = computation->AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
     input = MakePadHlo(input, padding, padding_config).ValueOrDie();
   }
 
@@ -125,9 +124,8 @@
 
   HloComputation* computation = kernel->parent();
   PrimitiveType element_type = kernel->shape().element_type();
-  HloInstruction* padding =
-      computation->AddInstruction(HloInstruction::CreateConstant(
-          absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+  HloInstruction* padding = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
   return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
 }
 }  // namespace
@@ -236,9 +234,9 @@
   // Create a new backward convolution replacing the old one.
   HloComputation* computation = backward_conv->parent();
   HloInstruction* output = backward_conv->mutable_operand(1);
-  HloInstruction* padding = computation->AddInstruction(
-      HloInstruction::CreateConstant(absl::make_unique<Literal>(
-          LiteralUtil::Zero(input->shape().element_type()))));
+  HloInstruction* padding =
+      computation->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(input->shape().element_type())));
   HloInstruction* padded_input =
       MakePadHlo(input, padding, input_padding_config).ValueOrDie();
 
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 091aca2..c4f43cc 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -21,13 +21,14 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 
 namespace xla {
 namespace gpu {
 
-class StreamAssignmentTest : public HloTestBase {
+class StreamAssignmentTest : public HloVerifiedTestBase {
  protected:
   std::unique_ptr<HloModule> CreateNewModule() {
     HloModuleConfig config;
@@ -49,10 +50,10 @@
       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
   HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
-  HloInstruction* dot1 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
-  HloInstruction* dot2 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+  HloInstruction* dot1 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+  HloInstruction* dot2 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build(dot2));
@@ -68,10 +69,10 @@
       /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
   HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
       /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
-  HloInstruction* dot1 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
-  HloInstruction* dot2 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
+  HloInstruction* dot1 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+  HloInstruction* dot2 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
   HloInstruction* add = builder.AddInstruction(
       HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
 
@@ -101,23 +102,23 @@
         i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
   }
   HloInstruction* d00 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
-  HloInstruction* d10 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
-  HloInstruction* d11 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
-  HloInstruction* d20 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
-  HloInstruction* d21 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
-  HloInstruction* d22 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
-  HloInstruction* d30 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
-  HloInstruction* d31 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
-  HloInstruction* d40 = builder.AddInstruction(
-      HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+      CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+  HloInstruction* d10 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+  HloInstruction* d11 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+  HloInstruction* d20 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+  HloInstruction* d21 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+  HloInstruction* d22 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+  HloInstruction* d30 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+  HloInstruction* d31 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+  HloInstruction* d40 =
+      builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index 4550f36..780539c 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -38,8 +38,7 @@
 TEST_F(GpuCopyTest, UseMemcpy) {
   HloComputation::Builder builder(TestName());
 
-  std::unique_ptr<Literal> literal =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   HloInstruction* constant = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
   builder.AddInstruction(HloInstruction::CreateUnary(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index 9072b30..f8120a5 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -53,40 +53,40 @@
 };
 
 TEST_F(InfeedTest, SingleInfeedR0Bool) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
+  TestInfeedRoundTrip(LiteralUtil::CreateR0<bool>(true));
 }
 
 TEST_F(InfeedTest, SingleInfeedR1U32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+  TestInfeedRoundTrip(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 }
 
 TEST_F(InfeedTest, SingleInfeedR2F32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+  TestInfeedRoundTrip(LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
 }
 
 TEST_F(InfeedTest, SingleInfeedR3F32) {
   TestInfeedRoundTrip(
-      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
-                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+      LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+                             {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 }
 
 TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
   const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
   const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
 
-  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
       {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
        {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
       r3_dim0minor));
 
-  TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+  TestInfeedRoundTrip(LiteralUtil::CreateR3WithLayout(
       {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
        {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
       r3_dim0major));
 }
 
 TEST_F(InfeedTest, SingleInfeedR4S32) {
-  TestInfeedRoundTrip(*LiteralUtil::CreateR4(
+  TestInfeedRoundTrip(LiteralUtil::CreateR4(
       {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
        {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 }
@@ -95,26 +95,26 @@
 TEST_F(InfeedTest, LargeInfeed) {
   Array4D<float> array(80, 100, 8, 128);
   array.FillIota(1.0f);
-  TestInfeedRoundTrip(*LiteralUtil::CreateR4FromArray4D<float>(array));
+  TestInfeedRoundTrip(LiteralUtil::CreateR4FromArray4D<float>(array));
 }
 
 TEST_F(InfeedTest, SingleInfeedTuple) {
-  TestInfeedRoundTrip(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
-                               LiteralUtil::CreateR0<bool>(false).get()}));
+  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<uint32>({1, 2, 3}),
+       LiteralUtil::CreateR0<bool>(false)}));
 }
 
 TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
-  TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
+  TestInfeedRoundTrip(LiteralUtil::MakeTuple({}));
 }
 
 // Tests that a large tuple infeed can be handled.
 TEST_F(InfeedTest, SingleInfeedLargeTuple) {
   Array4D<float> array(40, 100, 8, 128);
   array.FillIota(1.0f);
-  TestInfeedRoundTrip(*LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR4FromArray4D<float>(array).get(),
-       LiteralUtil::CreateR0<int32>(5).get()}));
+  TestInfeedRoundTrip(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR4FromArray4D<float>(array),
+       LiteralUtil::CreateR0<int32>(5)}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 40183de..9a61f8a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -26,9 +26,6 @@
 namespace xla {
 namespace {
 
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
 class WhileTransformerTest : public HloTestBase {
  protected:
   WhileTransformerTest()
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index a2be895..ef70b68 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -112,8 +112,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      /*new_size=*/2, PrecisionConfig::DEFAULT);
+  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+      vshape, clamp, param_v0, dot_dnums, precision_config));
   auto tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({dot, param_s, clamp}));
   auto scalar = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 38c3982..e0f3a7e 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -29,13 +29,13 @@
 
 /*static*/
 StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
-    const SequentialHloOrdering::HloModuleSequence& module_sequence,
+    const HloSchedule& schedule,
     const LogicalBuffer::SizeFunction& size_function) {
-  if (module_sequence.empty()) {
+  if (schedule.empty()) {
     return 0;
   }
 
-  const HloModule* module = module_sequence.begin()->first->parent();
+  const HloModule* module = schedule.module();
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(module));
 
@@ -47,14 +47,13 @@
   TF_ASSIGN_OR_RETURN(
       HeapSimulator::Result result,
       HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module,
-                         module_sequence, *points_to_analysis, size_function));
+                         schedule, *points_to_analysis, size_function));
   return result.heap_size;
 }
 
 /*static*/
 StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
-    const HloComputation& computation,
-    const std::vector<const HloInstruction*>& sequence,
+    const HloComputation& computation, const HloInstructionSequence& sequence,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
@@ -71,13 +70,13 @@
 /*static*/
 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
-    const SequentialHloOrdering::HloModuleSequence& module_sequence,
+    const HloSchedule& schedule,
     const TuplePointsToAnalysis& points_to_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options) {
-  HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence);
+  HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule);
   const HloComputation* entry_computation = module.entry_computation();
-  const std::vector<const HloInstruction*>& instruction_sequence =
-      FindOrDie(module_sequence, entry_computation);
+  const HloInstructionSequence& instruction_sequence =
+      schedule.sequence(entry_computation);
   TF_RETURN_IF_ERROR(heap.RunComputation(
       *entry_computation, instruction_sequence, points_to_analysis));
   return heap.Finish();
@@ -86,13 +85,13 @@
 /*static*/
 StatusOr<HeapSimulator::Result> HeapSimulator::Run(
     std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
-    const std::vector<const HloInstruction*>& instruction_sequence,
+    const HloInstructionSequence& instruction_sequence,
     const TuplePointsToAnalysis& points_to_analysis,
     const BufferValue::SizeFunction& size_fn, const Options& options,
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
         memory_by_computation) {
   HeapSimulator heap(std::move(algorithm), size_fn, options,
-                     /*module_sequence=*/nullptr, memory_by_computation);
+                     /*schedule=*/nullptr, memory_by_computation);
   TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
                                          points_to_analysis));
   return heap.Finish();
@@ -102,7 +101,7 @@
 // 'instruction_sequence'.
 Status HeapSimulator::RunComputation(
     const HloComputation& computation,
-    const std::vector<const HloInstruction*>& instruction_sequence,
+    const HloInstructionSequence& instruction_sequence,
     const TuplePointsToAnalysis& points_to_analysis) {
   VLOG(3) << "Computation:\n" << computation.ToString();
   // The goal here is to minimize memory usage, assuming the given sequential
@@ -133,7 +132,8 @@
   // set of instructions that need to be visited contains all users of all
   // aliases, that is, all users of all instructions that have the buffer
   // contained in their points-to set.
-  for (const HloInstruction* instruction : instruction_sequence) {
+  for (const HloInstruction* instruction :
+       instruction_sequence.instructions()) {
     const PointsToSet& points_to =
         points_to_analysis.GetPointsToSet(instruction);
     const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet();
@@ -166,7 +166,8 @@
 
   std::vector<const BufferValue*> dead_buffers_to_free;
   std::vector<const BufferValue*> operand_buffers_to_free;
-  for (const HloInstruction* instruction : instruction_sequence) {
+  for (const HloInstruction* instruction :
+       instruction_sequence.instructions()) {
     const TuplePointsToAnalysis::BufferDefinitionVector&
         buffers_defined_by_instruction =
             points_to_analysis.GetBuffersDefinedByInstruction(instruction);
@@ -285,14 +286,14 @@
     // The order that the sub-computations are simulated does not affect
     // correctness; since the whole module has been scheduled, we know that the
     // sub-computations will never be run concurrently.
-    if (module_sequence_ != nullptr) {
+    if (schedule_ != nullptr) {
       if (instruction->opcode() == HloOpcode::kCall ||
           instruction->opcode() == HloOpcode::kConditional ||
           instruction->opcode() == HloOpcode::kWhile) {
         for (const HloComputation* called_computation :
              instruction->called_computations()) {
-          const std::vector<const HloInstruction*>& called_sequence =
-              FindOrDie(*module_sequence_, called_computation);
+          const HloInstructionSequence& called_sequence =
+              schedule_->sequence(called_computation);
           TF_RETURN_IF_ERROR(RunComputation(
               *called_computation, called_sequence, points_to_analysis));
         }
@@ -343,16 +344,16 @@
 HeapSimulator::HeapSimulator(
     std::unique_ptr<HeapAlgorithm> algorithm,
     const BufferValue::SizeFunction& size_fn, const Options& options,
-    const SequentialHloOrdering::HloModuleSequence* module_sequence,
+    const HloSchedule* schedule,
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
         memory_by_computation)
     : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
       algorithm_(std::move(algorithm)),
       size_fn_(size_fn),
       options_(options),
-      module_sequence_(module_sequence),
+      schedule_(schedule),
       memory_by_computation_(memory_by_computation) {
-  debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
+  debug_trace_.set_whole_module_simulation(schedule_ != nullptr);
 }
 
 HeapSimulator::~HeapSimulator() {}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index af05bed..ffbf947 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -27,6 +27,7 @@
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
@@ -88,23 +89,22 @@
 
   // Returns the minimum memory required to compute an HLO module where all
   // computations have been scheduled (represented by the given
-  // module_sequence), assuming no fragmentation.
+  // schedule), assuming no fragmentation.
   static StatusOr<int64> MinimumMemoryForModule(
-      const SequentialHloOrdering::HloModuleSequence& module_sequence,
+      const HloSchedule& schedule,
       const LogicalBuffer::SizeFunction& size_function);
 
   // Returns the minimum memory required to compute the given computation,
   // assuming no fragmentation.
   static StatusOr<int64> MinimumMemoryForComputation(
-      const HloComputation& computation,
-      const std::vector<const HloInstruction*>& sequence,
+      const HloComputation& computation, const HloInstructionSequence& sequence,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_function,
       const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
           memory_by_computation = nullptr);
 
   // Run the heap simulation with the given algorithm, assuming the given
-  // module_sequence, which must contain a topologically-consistent total
+  // schedule, which must contain a topologically-consistent total
   // ordering of all instructions within each computation. The result is invalid
   // if instructions are not run in exactly this sequence.
   //
@@ -112,12 +112,12 @@
   // to running on a per-computation basis, since we can re-use buffer space for
   // called sub-computations.
   //
-  static StatusOr<Result> Run(
-      std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
-      const SequentialHloOrdering::HloModuleSequence& module_sequence,
-      const TuplePointsToAnalysis& points_to_analysis,
-      const BufferValue::SizeFunction& size_fn,
-      const Options& options = Options());
+  static StatusOr<Result> Run(std::unique_ptr<HeapAlgorithm> algorithm,
+                              const HloModule& module,
+                              const HloSchedule& schedule,
+                              const TuplePointsToAnalysis& points_to_analysis,
+                              const BufferValue::SizeFunction& size_fn,
+                              const Options& options = Options());
 
   // Same as above, but runs on a single computation. The 'instruction_sequence'
   // must contain a topologically-consistent total ordering of all instructions
@@ -126,7 +126,7 @@
   static StatusOr<Result> Run(
       std::unique_ptr<HeapAlgorithm> algorithm,
       const HloComputation& computation,
-      const std::vector<const HloInstruction*>& instruction_sequence,
+      const HloInstructionSequence& instruction_sequence,
       const TuplePointsToAnalysis& points_to_analysis,
       const BufferValue::SizeFunction& size_fn,
       const Options& options = Options(),
@@ -134,21 +134,19 @@
           memory_by_computation = nullptr);
 
  private:
-  // If 'module_sequence' is non-null, it is used to find kCall and kWhile
+  // If 'schedule' is non-null, it is used to find kCall and kWhile
   // sub-computations, and the heap simulation for those sub-computations will
   // be run recursively. I.e. the simulation is run over the whole module.
-  HeapSimulator(
-      std::unique_ptr<HeapAlgorithm> algorithm,
-      const BufferValue::SizeFunction& size_fn, const Options& options,
-      const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
-      const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
-          memory_by_computation = nullptr);
+  HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
+                const BufferValue::SizeFunction& size_fn,
+                const Options& options, const HloSchedule* schedule = nullptr,
+                const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+                    memory_by_computation = nullptr);
   ~HeapSimulator();
 
-  Status RunComputation(
-      const HloComputation& computation,
-      const std::vector<const HloInstruction*>& instruction_sequence,
-      const TuplePointsToAnalysis& points_to_analysis);
+  Status RunComputation(const HloComputation& computation,
+                        const HloInstructionSequence& instruction_sequence,
+                        const TuplePointsToAnalysis& points_to_analysis);
 
   bool IgnoreBuffer(const BufferValue* buffer) const;
   void Alloc(const BufferValue* buffer, const HloInstruction* instruction);
@@ -169,11 +167,11 @@
   const std::unique_ptr<HeapAlgorithm> algorithm_;
   const BufferValue::SizeFunction size_fn_;
   const Options options_;
-  // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+  // schedule_ is set by buffer assignment, and memory_by_computation_ is
   // set by hlo scheduling. Then, in RunComputation, we check both in order to
   // handle subcomputations. It would be good to unify the handling of
   // subcomputations, but it's not clear how.
-  const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+  const HloSchedule* schedule_;
   const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
       memory_by_computation_;
 
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 5f85f14..957c4a6 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -29,13 +29,14 @@
 #include "tensorflow/compiler/xla/service/hlo_value.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 
 namespace xla {
 namespace {
 
-class MinimumMemoryForSequenceTest : public HloTestBase {};
+class MinimumMemoryForSequenceTest : public HloVerifiedTestBase {};
 
 TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
   auto module = CreateNewModule();
@@ -85,13 +86,16 @@
     return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
   };
 
-  SequentialHloOrdering::HloModuleSequence module_sequence;
-  module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
-                                       cond_lt};
-  module_sequence[body_computation] = {body_param};
-  module_sequence[entry_computation] = {iter, data, tuple, while_op};
-  EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
-                    .ValueOrDie());
+  HloSchedule schedule(module);
+  schedule.set_sequence(cond_computation,
+                        {cond_param, cond_iter, cond_data, cond_lt});
+  schedule.set_sequence(body_computation, {body_param});
+  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(
+      56,
+      HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie());
 }
 
 const char kAlloc[] = "Alloc";
@@ -149,10 +153,11 @@
     auto zero_size = [](const BufferValue& buffer) { return 0; };
     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
         absl::make_unique<HeapCallRecorder>(&actual_calls_));
-    result_ = HeapSimulator::Run(
-                  std::move(algorithm), *module_->entry_computation(),
-                  instruction_sequence, *points_to_analysis_, zero_size)
-                  .ConsumeValueOrDie();
+    result_ =
+        HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(),
+                           HloInstructionSequence(instruction_sequence),
+                           *points_to_analysis_, zero_size)
+            .ConsumeValueOrDie();
   }
 
   explicit HeapSimulatorTracker(const string& name) {
@@ -168,11 +173,12 @@
         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
 
     // Construct the module sequence grouped by computation.
-    SequentialHloOrdering::HloModuleSequence module_sequence;
+    HloSchedule schedule(module_.get());
     tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
     for (int i = 0; i < full_module_sequence.size(); ++i) {
       const HloInstruction* instruction = full_module_sequence[i];
-      module_sequence[instruction->parent()].push_back(instruction);
+      schedule.GetOrCreateSequence(instruction->parent())
+          .push_back(instruction);
       reverse_position[instruction] = full_module_sequence.size() - i;
     }
 
@@ -185,8 +191,8 @@
     };
     auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>(
         absl::make_unique<HeapCallRecorder>(&actual_calls_));
-    result_ = HeapSimulator::Run(std::move(algorithm), *module_,
-                                 module_sequence, *points_to_analysis_, size_fn)
+    result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule,
+                                 *points_to_analysis_, size_fn)
                   .ConsumeValueOrDie();
   }
 
@@ -227,7 +233,7 @@
   HeapSimulator::Result result_;
 };
 
-class HeapSimulatorTest : public HloTestBase {
+class HeapSimulatorTest : public HloVerifiedTestBase {
  protected:
   HeapSimulatorTest() {}
   ~HeapSimulatorTest() override {}
@@ -366,8 +372,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
 
   // The buffer for dot is the output, and it cannot be shared with the buffer
   // for mul, since dot isn't elementwise.
@@ -402,8 +408,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
+  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
   auto add = builder.AddInstruction(
       HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
 
@@ -440,10 +446,10 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot0 = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
-  auto dot1 = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
 
   // The buffer for dot1 is the output.  No buffers can be shared.  The buffer
   // for mul is freed before the end, since it's no longer used after dot0
@@ -481,10 +487,10 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot0 = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
-  auto dot1 = builder.AddInstruction(
-      HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
+  auto dot0 = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2)));
+  auto dot1 = builder.AddInstruction(HloInstruction::CreateDot(
+      f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2)));
   auto tuple =
       builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
 
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 58b7af9..b19ec12 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -172,7 +172,7 @@
   xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
 
   // Precision configuration for the instruction. Has backend-specific meaning.
-  xla.PrecisionConfigProto precision_config = 51;
+  xla.PrecisionConfig precision_config = 51;
 
   // Collective permute field.
   repeated SourceTarget source_target_pairs = 52;
@@ -199,6 +199,17 @@
   int64 root_id = 6;
 }
 
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+  message InstructionSequence {
+    repeated int64 instruction_ids = 1;
+  }
+
+  // Map from computation id to sequence.
+  map<int64, InstructionSequence> sequences = 1;
+}
+
 // Serialization of HloModule.
 message HloModuleProto {
   string name = 1;
@@ -214,16 +225,9 @@
 
   // The id of this module.
   int64 id = 5;
-}
 
-// Serialization of HloOrdering.
-message HloOrderingProto {
-  // NOTE: currently only sequential orderings are serialized.
-  message SequentialComputation {
-    string computation_name = 1;
-    repeated string instruction_names = 2;
-  }
-  repeated SequentialComputation sequential_computations = 1;
+  // The schedule for this module.
+  HloScheduleProto schedule = 7;
 }
 
 // Serialization of LogicalBuffer.
@@ -305,6 +309,13 @@
   bool whole_module_simulation = 2;
 }
 
+// An abstraction representing a set of HLO modules built to run concurrently
+// across different devices.
+message HloModuleGroupProto {
+  string name = 1;
+  repeated HloModuleProto hlo_modules = 2;
+}
+
 // Serialization of BufferAssignment.
 message BufferAssignmentProto {
   // Alias represents a source LogicalBuffer, and the buffer location that
@@ -322,8 +333,10 @@
 
 // Grouping message that contains all of the information above.
 message HloProto {
+  reserved 2;
+  reserved "hlo_ordering";
+
   HloModuleProto hlo_module = 1;
-  HloOrderingProto hlo_ordering = 2;
   BufferAssignmentProto buffer_assignment = 3;
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 54abe33..0cd0ab3 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -885,18 +885,20 @@
 
  // For a sequential order, there is interference iff the negate is after
  // the while.
-  SequentialHloOrdering::HloModuleSequence sequence;
-  sequence[body] = {body_param, body_root};
-  sequence[condition] = {cond_param, cond_root};
+  HloSchedule schedule(module_);
+  schedule.set_sequence(body, {body_param, body_root});
+  schedule.set_sequence(condition, {cond_param, cond_root});
   {
-    sequence[entry] = {init, xla_while, negate, entry_root};
-    SequentialHloOrdering ordering(module_, sequence);
+    schedule.set_sequence(entry, {init, xla_while, negate, entry_root});
+    TF_ASSERT_OK(schedule.Verify());
+    SequentialHloOrdering ordering(schedule);
     EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering));
   }
 
   {
-    sequence[entry] = {init, negate, xla_while, entry_root};
-    SequentialHloOrdering ordering(module_, sequence);
+    schedule.set_sequence(entry, {init, negate, xla_while, entry_root});
+    TF_ASSERT_OK(schedule.Verify());
+    SequentialHloOrdering ordering(schedule);
     EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering));
   }
 }
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index fe7f2be..8c6903d 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -464,6 +464,14 @@
 }
 
 string HloComputation::ToString(const HloPrintOptions& options) const {
+  return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+    const HloPrintOptions& options,
+    absl::Span<const HloInstruction* const> instruction_order) const {
+  CHECK_EQ(instruction_order.size(), instruction_count());
+
   std::ostringstream s;
   for (int i = 0; i < options.indent_amount(); i++) {
     s << "  ";
@@ -486,7 +494,9 @@
     new_options.set_indent_amount(options.indent_amount() + 1)
         .set_is_in_nested_computation(true);
     CanonicalNameMap name_map;
-    for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+    for (const HloInstruction* instruction : instruction_order) {
+      CHECK_EQ(this, instruction->parent());
+
       for (int i = 0; i < new_options.indent_amount(); i++) {
         s << "  ";
       }
@@ -552,9 +562,11 @@
               return to_proto_id[a.get()] < to_proto_id[b.get()];
             });
 
-  return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
-                                             &instructions, root,
-                                             /*fusion_instruction=*/nullptr));
+  auto computation = absl::WrapUnique(
+      new HloComputation(proto.name(), parameter_count, &instructions, root,
+                         /*fusion_instruction=*/nullptr));
+  computation->unique_id_ = proto.id();
+  return std::move(computation);
 }
 
 void HloComputation::FuseInstructionsInto(
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fe2d3bb..91c5234 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -170,6 +170,11 @@
   string ToString() const { return ToString(HloPrintOptions()); }
   string ToString(const HloPrintOptions& options) const;
 
+  // Overload which accepts an order to emit the instructions in.
+  string ToString(
+      const HloPrintOptions& options,
+      absl::Span<const HloInstruction* const> instruction_order) const;
+
   // Returns a serialized representation of this computation.
   HloComputationProto ToProto() const;
 
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index f7ed1b0..2aaaef1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -601,8 +601,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
   builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
 
@@ -633,8 +636,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
   builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
 
@@ -666,8 +672,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
   builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+      HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config));
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
 
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 8a45939..f837816 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -76,10 +76,10 @@
         continue;
       }
 
-      std::unique_ptr<Literal> result = evaluator->TryEvaluate(instruction);
+      Literal result;
       // Currently we skip unimplemented operations.
       // TODO(b/35975797): Fold constant computations for more operations.
-      if (result == nullptr) {
+      if (!evaluator->TryEvaluate(instruction, &result)) {
         VLOG(2) << "Constant folding failed for instruction: "
                 << instruction->ToString();
         continue;
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 07cd1ef..3e0def5 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
 
@@ -37,7 +37,7 @@
 namespace xla {
 namespace {
 
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
 
 TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
   HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
   EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
   EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@
   EXPECT_THAT(computation->root_instruction(), op::Convert(input));
 
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
   EXPECT_TRUE(result);
 
   EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@
     auto computation = module->AddEntryComputation(builder.Build());
 
     HloConstantFolding const_folder;
-    TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+    TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
     EXPECT_TRUE(result);
 
     HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@
   auto computation = module->AddEntryComputation(builder.Build());
 
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
   EXPECT_TRUE(result);
 
   HloInstruction* root = computation->root_instruction();
@@ -175,7 +175,7 @@
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
                           LiteralUtil::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
-  auto literal_clone = literal->Literal::CloneToUnique();
+  auto literal_clone = literal.Clone();
   HloInstruction* literal_instruction = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
@@ -186,7 +186,7 @@
   auto computation = module->AddEntryComputation(builder.Build());
 
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
   EXPECT_TRUE(result);
 
   HloInstruction* root = computation->root_instruction();
@@ -198,7 +198,7 @@
   root->literal().EachCell<NativeT>(
       [&](absl::Span<const int64> indices, NativeT value) {
         std::vector<int64> rindexes = Permute(permutation, indices);
-        matched = matched && (value == literal_clone->Get<NativeT>(rindexes));
+        matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
       });
   EXPECT_TRUE(matched);
 }
@@ -219,28 +219,27 @@
   })";
 
 TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(kConstantFoldReduce));
+  ParseAndVerifyModule(kConstantFoldReduce);
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
   EXPECT_TRUE(result);
 
-  EXPECT_EQ(6, module->entry_computation()
+  EXPECT_EQ(6, module()
+                   .entry_computation()
                    ->root_instruction()
                    ->literal()
                    .GetFirstElement<int32>());
 }
 
 TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(kConstantFoldReduce));
-  HloInstruction* add = module->computations().begin()->root_instruction();
+  ParseAndVerifyModule(kConstantFoldReduce);
+  HloInstruction* add = module().computations().begin()->root_instruction();
   LayoutUtil::ClearLayout(add->mutable_shape());
   HloConstantFolding const_folder;
-  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
   EXPECT_FALSE(result);
 
-  EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+  EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939b511..a502fff 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -227,6 +227,14 @@
   return Status::OK();
 }
 
+Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
+  // Domain does not have any computation or data transfer.
+  current_should_compute_bottleneck_time_ = false;
+  current_properties_[kBytesAccessedKey] = 0;
+  current_properties_[kOptimalSecondsKey] = 0;
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
   const Shape& lhs_shape = dot->operand(0)->shape();
   const Shape& rhs_shape = dot->operand(1)->shape();
@@ -507,8 +515,9 @@
     valid_position_counts.push_back(valid_position_count);
   }
 
-  const int64 fma_count =
-      input_feature * output_feature * batch * Product(valid_position_counts);
+  const int64 fma_count = (input_feature / convolution->feature_group_count()) *
+                          output_feature * batch *
+                          Product(valid_position_counts);
   current_properties_[kFlopsKey] = fma_count * kFmaFlops;
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 9bb3f12..46b4bbe 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@
   Status HandleRecvDone(const HloInstruction* recv_done) override;
   Status HandleConvert(const HloInstruction* convert) override;
   Status HandleCopy(const HloInstruction* copy) override;
+  Status HandleDomain(const HloInstruction* domain) override;
   Status HandleDot(const HloInstruction* dot) override;
   Status HandleConvolution(const HloInstruction* convolution) override;
   Status HandleFft(const HloInstruction* fft) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854ee..d76ce9e 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -203,6 +203,35 @@
             sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
 }
 
+TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
+  XlaBuilder builder("convolution");
+  auto input = Parameter(
+      &builder, 0,
+      ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
+                                 /*x_dim=*/20}),
+      "input");
+  auto kernel = Parameter(
+      &builder, 1,
+      ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
+                                 /*x_dim=*/3}),
+      "kernel");
+  Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);
+
+  // Run HLO cost analysis.
+  auto hlo_module = BuildHloGraph(&builder);
+  HloCostAnalysis analysis(ShapeSize);
+  ASSERT_IS_OK(
+      hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Output shape is [1x120x8x18] and each output element requires (3x3)
+  // FMAs and one FMA is 2 flops.
+  EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);
+
+  // Bytes accessed is sum of inputs and output.
+  EXPECT_EQ(analysis.bytes_accessed(),
+            sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
+}
+
 TEST_F(HloCostAnalysisTest, Reduce) {
   XlaBuilder builder("reduce");
   auto input =
@@ -415,7 +444,7 @@
 TEST_F(HloCostAnalysisTest, TupleCost) {
   HloCostAnalysis analysis(ShapeSize);
   {
-    XlaBuilder builder("matmul");
+    XlaBuilder builder("tuple");
     auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
     auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
     Tuple(&builder, {x, y});
@@ -430,6 +459,30 @@
   EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
 }
 
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+  HloCostAnalysis analysis(ShapeSize);
+
+  HloComputation::Builder builder("domain");
+  auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeShape(F32, {123}), "x"));
+  auto y = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+  auto domain = builder.AddInstruction(
+      HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+  auto hlo_module = CreateNewModule();
+  hlo_module->AddEntryComputation(builder.Build());
+
+  EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+  ASSERT_IS_OK(domain->Accept(&analysis));
+
+  EXPECT_EQ(analysis.flop_count(*domain), 0);
+  EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+  EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
 TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
   XlaBuilder builder("BaseDilatedConvolution");
   auto input = Parameter(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 19ffb46..b76c50b 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -61,15 +61,18 @@
 }
 
 StatusOr<HloInstruction*> MakeConvolveHlo(
-    HloInstruction* lhs, HloInstruction* rhs, const Window& window,
-    const ConvolutionDimensionNumbers& dimension_numbers) {
+    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config) {
   HloComputation* computation = lhs->parent();
   CHECK_EQ(computation, rhs->parent());
-  TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape(
-                                                lhs->shape(), rhs->shape(),
-                                                window, dimension_numbers));
+  TF_ASSIGN_OR_RETURN(Shape convolve_shape,
+                      ShapeInference::InferConvolveShape(
+                          lhs->shape(), rhs->shape(), feature_group_count,
+                          window, dimension_numbers));
   return computation->AddInstruction(HloInstruction::CreateConvolve(
-      convolve_shape, lhs, rhs, window, dimension_numbers));
+      convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+      precision_config));
 }
 
 StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
@@ -165,14 +168,15 @@
 }
 
 StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
-                                     const DotDimensionNumbers& dim_numbers) {
+                                     const DotDimensionNumbers& dim_numbers,
+                                     const PrecisionConfig& precision_config) {
   HloComputation* computation = lhs->parent();
   CHECK_EQ(computation, rhs->parent());
   TF_ASSIGN_OR_RETURN(
       Shape dot_shape,
       ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
-  return computation->AddInstruction(
-      HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
+  return computation->AddInstruction(HloInstruction::CreateDot(
+      dot_shape, lhs, rhs, dim_numbers, precision_config));
 }
 
 StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
@@ -317,18 +321,17 @@
   padding_config_dim.set_edge_padding_high(zeros_to_append);
   *padding_config.add_dimensions() = padding_config_dim;
 
-  HloInstruction* zero = computation->AddInstruction(
-      HloInstruction::CreateConstant(absl::make_unique<Literal>(
-          LiteralUtil::Zero(operand->shape().element_type()))));
+  HloInstruction* zero =
+      computation->AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::Zero(operand->shape().element_type())));
   return MakePadHlo(operand, zero, padding_config);
 }
 
 StatusOr<HloInstruction*> BroadcastZeros(
     HloComputation* computation, PrimitiveType element_type,
     absl::Span<const int64> broadcast_dimensions) {
-  HloInstruction* zero =
-      computation->AddInstruction(HloInstruction::CreateConstant(
-          absl::make_unique<Literal>(LiteralUtil::Zero(element_type))));
+  HloInstruction* zero = computation->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
   return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
                           /*result_shape_bounds=*/broadcast_dimensions);
 }
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index a1c4b37..b22058a 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -48,8 +48,9 @@
 // Creates a convolution HLO instruction and adds it to the computation
 // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
 StatusOr<HloInstruction*> MakeConvolveHlo(
-    HloInstruction* lhs, HloInstruction* rhs, const Window& window,
-    const ConvolutionDimensionNumbers& dimension_numbers);
+    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
+    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config);
 
 // Creates a transpose HLO instruction and adds it to the computation containing
 // `operand`.
@@ -98,7 +99,8 @@
 // Creates a Dot HLO instruction and adds it to the computation containing `lhs`
 // and `rhs` (both must be in the same computation).
 StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
-                                     const DotDimensionNumbers& dim_numbers);
+                                     const DotDimensionNumbers& dim_numbers,
+                                     const PrecisionConfig& precision_config);
 
 // Creates a Map HLO instruction and adds it to the computation containing the
 // operands. All operands must be in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index eb6affa..e07a196 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -57,10 +57,10 @@
   entry_computation->set_root_instruction(first_1_dims_collapsed);
 
   HloEvaluator evaluator;
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
-                          evaluator.Evaluate<std::unique_ptr<Literal>>(
+  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+                          evaluator.Evaluate<Literal>(
                               *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
 }
 
 TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -78,13 +78,13 @@
 
   HloEvaluator evaluator;
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result_literal,
-      evaluator.Evaluate<std::unique_ptr<Literal>>(
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(
           *module,
           {LiteralUtil::CreateR3<int32>(
               {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
-  CHECK_EQ(*result_literal,
-           *LiteralUtil::CreateR2<int32>(
+  CHECK_EQ(result_literal,
+           LiteralUtil::CreateR2<int32>(
                {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
 }
 
@@ -103,10 +103,10 @@
 
   HloEvaluator evaluator;
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result_literal,
-      evaluator.Evaluate<std::unique_ptr<Literal>>(
-          *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(*module,
+                                  {LiteralUtil::CreateR1<int32>({9, 10})}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
 }
 
 TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -124,10 +124,10 @@
 
   HloEvaluator evaluator;
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result_literal,
-      evaluator.Evaluate<std::unique_ptr<Literal>>(
-          *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(*module,
+                                  {LiteralUtil::CreateR1<int32>({9, 10})}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
 }
 
 TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -144,10 +144,10 @@
   entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
 
   HloEvaluator evaluator;
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
-                          evaluator.Evaluate<std::unique_ptr<Literal>>(
-                              *module, {LiteralUtil::CreateR0<int32>(9)}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(9)}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
 }
 
 TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -165,11 +165,11 @@
 
   HloEvaluator evaluator;
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result_literal,
-      evaluator.Evaluate<std::unique_ptr<Literal>>(
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(
           *module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
-  CHECK_EQ(*result_literal,
-           *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+  CHECK_EQ(result_literal,
+           LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
 }
 
 TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -187,10 +187,10 @@
   entry_computation->set_root_instruction(zero_padded_param);
 
   HloEvaluator evaluator;
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
-                          evaluator.Evaluate<std::unique_ptr<Literal>>(
+  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+                          evaluator.Evaluate<Literal>(
                               *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
 }
 
 TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -208,10 +208,10 @@
   entry_computation->set_root_instruction(zeros);
 
   HloEvaluator evaluator;
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
-                          evaluator.Evaluate<std::unique_ptr<Literal>>(
-                              *module, {LiteralUtil::CreateR0<int32>(0)}));
-  CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
+  TF_ASSERT_OK_AND_ASSIGN(
+      Literal result_literal,
+      evaluator.Evaluate<Literal>(*module, {LiteralUtil::CreateR0<int32>(0)}));
+  CHECK_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
 }
 
 TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -229,11 +229,11 @@
   entry_computation->set_root_instruction(zeros);
 
   HloEvaluator evaluator;
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
-                          evaluator.Evaluate<std::unique_ptr<Literal>>(
+  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
+                          evaluator.Evaluate<Literal>(
                               *module, {LiteralUtil::CreateR0<float>(0.0f)}));
-  CHECK_EQ(*result_literal,
-           *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+  CHECK_EQ(result_literal,
+           LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index cb367ad..b59c9ba 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -23,6 +23,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/container/inlined_vector.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,7 +35,6 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/hash/hash.h"
 
 namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 406d712..9b18b02 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -29,7 +29,7 @@
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -44,7 +44,7 @@
 namespace xla {
 namespace {
 
-class HloCseTest : public HloTestBase {
+class HloCseTest : public HloVerifiedTestBase {
  protected:
   HloCseTest() {}
 };
@@ -65,15 +65,15 @@
   EXPECT_EQ(3, computation->instruction_count());
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(2, computation->instruction_count());
   HloInstruction* constant = *computation->instructions().begin();
   EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
 
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
   auto expected = LiteralUtil::CreateR0<float>(84.0);
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -96,16 +96,16 @@
   EXPECT_THAT(add, op::Add(constant1, constant2));
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(2, computation->instruction_count());
   auto first_operand = add->operand(0);
   EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2));
   EXPECT_THAT(add, op::Add(first_operand, first_operand));
 
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
   auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -128,14 +128,14 @@
   EXPECT_THAT(add, op::Add(constant1, constant2));
 
   HloCSE cse(/*is_layout_sensitive=*/true);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(3, computation->instruction_count());
   EXPECT_THAT(add, op::Add(constant1, constant2));
 
-  auto result = ExecuteAndTransfer(std::move(module), {});
+  auto result = ExecuteAndTransfer(module->Clone(), {});
   auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
@@ -177,7 +177,7 @@
   EXPECT_EQ(20, computation->instruction_count());
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   // CSE will remove both the second float(42.0f) and the corresponding
   // convert/cast.
@@ -209,7 +209,7 @@
               op::Tuple(common_constant1, common_constant2, uncommon_constant));
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(3, computation->instruction_count());
   auto first_operand = tuple->operand(0);
@@ -240,7 +240,7 @@
   EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
 
   HloCSE cse(/*is_layout_sensitive=*/true);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(3, computation->instruction_count());
   auto first_operand = tuple->operand(0);
@@ -250,7 +250,7 @@
 
 // Test two identical while loops with same inputs
 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
     HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
 
     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -278,21 +278,20 @@
 %while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
 condition=%condition.1, body=%body
     }
-    )")
-                    .ValueOrDie();
+    )");
 
-  auto computation = module->entry_computation();
+  auto computation = module().entry_computation();
 
   EXPECT_EQ(5, computation->instruction_count());
   HloCSE cse(true);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
   EXPECT_EQ(4, computation->instruction_count());
 }
 
 // Test two while loops with same conditions, same inputs, but different
 // bodies
 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
     HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
 
     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -329,20 +328,19 @@
 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
 f32[]) %tuple.1), condition=%condition.1, body=%body2
     }
-    )")
-                    .ValueOrDie();
+    )");
 
-  auto computation = module->entry_computation();
+  auto computation = module().entry_computation();
 
   EXPECT_EQ(5, computation->instruction_count());
   HloCSE cse(true);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
   EXPECT_EQ(5, computation->instruction_count());
 }
 
 // Test two identical while loops with different inputs
 TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
     HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
 
     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -373,21 +371,20 @@
 condition=%condition.1, body=%body
     }
 
-    )")
-                    .ValueOrDie();
+    )");
 
-  auto computation = module->entry_computation();
+  auto computation = module().entry_computation();
 
   EXPECT_EQ(8, computation->instruction_count());
   HloCSE cse(true);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
   EXPECT_EQ(8, computation->instruction_count());
 }
 
 // Test two while loops with identical bodies and same inputs, but different
 // conditions
 TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
     HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
 
     %body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -414,14 +411,13 @@
       %while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
 condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
 f32[]) %tuple.1), condition=%condition.1, body=%body
-    })")
-                    .ValueOrDie();
+    })");
 
-  auto computation = module->entry_computation();
+  auto computation = module().entry_computation();
 
   EXPECT_EQ(5, computation->instruction_count());
   HloCSE cse(true);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
   EXPECT_EQ(5, computation->instruction_count());
 }
 
@@ -450,7 +446,7 @@
   EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
 
   HloCSE cse(/*is_layout_sensitive=*/true);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(4, computation->instruction_count());
   EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
@@ -481,7 +477,7 @@
   EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(3, computation->instruction_count());
   auto first_operand = tuple->operand(0);
@@ -516,7 +512,7 @@
 
   EXPECT_EQ(5, fused_computation->instruction_count());
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
   EXPECT_EQ(4, fused_computation->instruction_count());
 
   auto root = fused_computation->root_instruction();
@@ -565,7 +561,7 @@
   EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(5, computation->instruction_count());
   auto operand = tuple->operand(0);
@@ -599,7 +595,7 @@
   uint32 count_before = computation->instruction_count();
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(module).ValueOrDie());
 
   uint32 count_after = computation->instruction_count();
   EXPECT_EQ(count_before, count_after);
@@ -653,7 +649,7 @@
   VLOG(3) << "before: " << module->ToString();
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(module).ValueOrDie());
 
   VLOG(3) << "after: " << module->ToString();
 
@@ -663,7 +659,7 @@
 }
 
 TEST_F(HloCseTest, CompareComputations) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
     HloModule m
 
     add_computation {
@@ -684,12 +680,11 @@
       r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
       r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
       ROOT f2 = (f32[],f32[]) tuple(r1, r2)
-    })")
-                    .ValueOrDie();
+    })");
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
-  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+  HloInstruction* root = module().entry_computation()->root_instruction();
   EXPECT_EQ(root->operand(0), root->operand(1));
 }
 
@@ -708,13 +703,13 @@
   EXPECT_EQ(2, computation->instruction_count());
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(cse.Run(module).ValueOrDie());
 
   EXPECT_EQ(2, computation->instruction_count());
 }
 
 TEST_F(HloCseTest, Domain) {
-  auto module = ParseHloString(R"(
+  ParseAndVerifyModule(R"(
 HloModule module
 ENTRY %entry {
   %param = f32[] parameter(0), sharding={maximal device=0}
@@ -735,13 +730,11 @@
     domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
   %add = f32[] add(%domain.3, %domain.4)
   ROOT %sub = f32[] subtract(%add, %domain.5)
-})")
-                    .ValueOrDie();
+})");
 
   HloCSE cse(/*is_layout_sensitive=*/false);
-  EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
-  LOG(INFO) << "AAAAA " << module->ToString();
-  const HloInstruction* sub = module->entry_computation()->root_instruction();
+  EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+  const HloInstruction* sub = module().entry_computation()->root_instruction();
   const HloInstruction* add = sub->operand(0);
   EXPECT_EQ(add->operand(0), add->operand(1));
   EXPECT_NE(add->operand(0), sub->operand(1));
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d1a96c1..510d636 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -1261,9 +1262,10 @@
   auto entry = module_->AddEntryComputation(builder.Build());
   RunAnalysis(GetParam());
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  sequence.insert({entry, {param0, negate, param1, exp, add}});
-  SequentialHloOrdering ordering(module_.get(), sequence);
+  HloSchedule schedule(module_.get());
+  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
+  TF_ASSERT_OK(schedule.Verify());
+  SequentialHloOrdering ordering(schedule);
 
   // Entry parameters interfere as if they are defined simultaneously at
   // the very beginning.
@@ -1339,14 +1341,16 @@
   bool ssa_form = GetParam();
   RunAnalysis(ssa_form);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  sequence.insert({entry, {param, xla_while}});
-  sequence.insert({condition, {cond_param, cond_constant}});
+  HloSchedule schedule(module_.get());
+  schedule.set_sequence(entry, {param, xla_while});
+  schedule.set_sequence(condition, {cond_param, cond_constant});
   // Construct the order such that 'constant' and its use 'exp' are before
   // body_param.
-  sequence.insert({body, {constant, exp, body_param, add}});
+  schedule.set_sequence(
+      body, {constant, exp, body_param, add, dead_constant, dead_negate});
+  TF_ASSERT_OK(schedule.Verify());
 
-  SequentialHloOrdering ordering(module_.get(), sequence);
+  SequentialHloOrdering ordering(schedule);
 
   // 'add' is live out of the body and will interfere with an later instructions
   // such as 'dead_constant' and 'dead_negate'.
@@ -1476,11 +1480,10 @@
   auto entry = module_->AddEntryComputation(builder.Build());
   RunAnalysis(GetParam());
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  std::vector<const HloInstruction*> order = {param, negate, exp, add};
-  sequence.emplace(entry, order);
-
-  SequentialHloOrdering ordering(module_.get(), sequence);
+  HloSchedule schedule(module_.get());
+  schedule.set_sequence(entry, {param, negate, exp, add});
+  TF_ASSERT_OK(schedule.Verify());
+  SequentialHloOrdering ordering(schedule);
 
   EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate));
   EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp));
@@ -2334,8 +2337,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
   auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
 
   auto one = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 8b2846e..113fd18 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -51,6 +51,10 @@
   return FindOrDefault(instruction_to_domain_, instruction, -1);
 }
 
+int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+  return FindOrDie(domain_metadata_id_, instruction);
+}
+
 Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
   TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
   // We only check operands, so we are sure to not process the empty domain from
@@ -93,6 +97,43 @@
                         CreateDomain(instruction, instructions_post_order));
     TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
   }
+  TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
+  return Status::OK();
+}
+
+Status HloDomainMap::PopulateDomainMetadataMap() {
+  auto hash = [](const DomainMetadata* m) { return m->Hash(); };
+  auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
+    return a->Matches(*b);
+  };
+  tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
+                           decltype(equal)>
+      domain_metadata(1024, hash, equal);
+
+  for (auto& domain : instruction_domains_) {
+    int64 domain_metadata_id = -1;
+    if (!domain->enter_domains.empty()) {
+      const HloInstruction* domain_instruction = *domain->enter_domains.begin();
+      domain_metadata_id =
+          domain_metadata
+              .insert({&domain_instruction->user_side_metadata(),
+                       domain_metadata.size() + 1})
+              .first->second;
+    } else if (!domain->exit_domains.empty()) {
+      const HloInstruction* domain_instruction = *domain->exit_domains.begin();
+      domain_metadata_id =
+          domain_metadata
+              .insert({&domain_instruction->operand_side_metadata(),
+                       domain_metadata.size() + 1})
+              .first->second;
+    } else {
+      domain_metadata_id = 0;
+    }
+    TF_RET_CHECK(domain_metadata_id >= 0);
+    for (HloInstruction* instruction : domain->instructions) {
+      domain_metadata_id_[instruction] = domain_metadata_id;
+    }
+  }
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 6331092..56b557d 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -69,6 +69,11 @@
   // instruction is not found within any domain.
   int64 GetDomainId(HloInstruction* instruction) const;
 
+  // Returns the unique id of the domain metadata for the domain the given
+  // instruction belongs to. The given instruction must not be a kDomain
+  // instruction since each domain instruction is associated with 2 domains.
+  int64 GetDomainMetadataId(HloInstruction* instruction) const;
+
  private:
   // Map used for representing instruction ordering, i.e.
   // order_map[a] < order_map[b] means a must be ordered before b.
@@ -109,9 +114,14 @@
       const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
       const InstructionOrderMap& instructions_order);
 
+  // Populates domain_metadata_id_ that maps each HloInstruction to the unique
+  // ID of its associated domain metatadata.
+  Status PopulateDomainMetadataMap();
+
   string domain_kind_;
   std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
   tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
+  tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 6c142ee..302807f 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -72,6 +72,9 @@
   // two matches.
   virtual bool Matches(const DomainMetadata& other) const = 0;
 
+  // Returns the hash value of the metadata.
+  virtual size_t Hash() const = 0;
+
   // Returns a string representation of the metadata.
   virtual string ToString() const = 0;
 };
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index 974ab94..43e74d2 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -99,6 +99,8 @@
 
   static absl::string_view KindName() { return "opname"; }
 
+  size_t Hash() const override { return std::hash<string>()(opname_); }
+
  private:
   string opname_;
 };
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 441dcad..06b6d5b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -53,11 +53,9 @@
 
 namespace {
 
-
 template <typename OperandT>
-StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode,
-                                           LiteralSlice lhs_literal,
-                                           LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare(const Shape& shape, HloOpcode opcode,
+                          LiteralSlice lhs_literal, LiteralSlice rhs_literal) {
   std::function<bool(OperandT, OperandT)> compare_op;
   switch (opcode) {
     case HloOpcode::kEq:
@@ -95,9 +93,9 @@
                  << HloOpcodeString(opcode);
   }
 
-  auto result = absl::make_unique<Literal>(shape);
+  Literal result(shape);
   TF_RETURN_IF_ERROR(
-      result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
         return compare_op(lhs_literal.Get<OperandT>(multi_index),
                           rhs_literal.Get<OperandT>(multi_index));
       }));
@@ -106,9 +104,9 @@
 }
 
 template <>
-StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
-    const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal,
-    LiteralSlice rhs_literal) {
+StatusOr<Literal> Compare<complex64>(const Shape& shape, HloOpcode opcode,
+                                     LiteralSlice lhs_literal,
+                                     LiteralSlice rhs_literal) {
   std::function<bool(complex64, complex64)> compare_op;
   switch (opcode) {
     case HloOpcode::kEq:
@@ -126,9 +124,9 @@
                  << HloOpcodeString(opcode);
   }
 
-  auto result = absl::make_unique<Literal>(shape);
+  Literal result(shape);
   TF_RETURN_IF_ERROR(
-      result->Populate<bool>([&](absl::Span<const int64> multi_index) {
+      result.Populate<bool>([&](absl::Span<const int64> multi_index) {
         return compare_op(lhs_literal.Get<complex64>(multi_index),
                           rhs_literal.Get<complex64>(multi_index));
       }));
@@ -194,7 +192,7 @@
 }
 
 template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
     const HloModule& module, absl::Span<const LiteralPtr> arg_literals) {
   XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString());
 
@@ -207,11 +205,21 @@
   TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this));
 
   return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())
-      .CloneToUnique();
+      .Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+    const HloModule& module, absl::Span<const Literal> arg_literals) {
+  std::vector<const Literal*> arg_literal_ptrs;
+  for (const auto& literal_ptr : arg_literals) {
+    arg_literal_ptrs.push_back(&literal_ptr);
+  }
+  return Evaluate<const Literal*>(module, arg_literal_ptrs);
 }
 
 template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
     const HloComputation& computation,
     absl::Span<const LiteralPtr> arg_literals) {
   CHECK(computation.parent() != nullptr);
@@ -225,11 +233,21 @@
   }
 
   TF_RETURN_IF_ERROR(computation.Accept(this));
-  return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique();
+  return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
+}
+
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+    const HloComputation& computation, absl::Span<const Literal> arg_literals) {
+  std::vector<const Literal*> arg_literal_ptrs;
+  for (const auto& literal_ptr : arg_literals) {
+    arg_literal_ptrs.push_back(&literal_ptr);
+  }
+  return Evaluate<const Literal*>(computation, arg_literal_ptrs);
 }
 
 template <typename LiteralPtr>
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
+StatusOr<Literal> HloEvaluator::Evaluate(
     HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals) {
   TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction));
 
@@ -248,18 +266,27 @@
               << input_literal->ToString();
       TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape()));
 
-      evaluated_[operand] = input_literal->CloneToUnique();
+      evaluated_[operand] = input_literal->Clone();
     }
   }
 
   TF_RETURN_IF_ERROR(Preprocess(instruction));
   TF_RETURN_IF_ERROR(instruction->Visit(this));
   TF_RETURN_IF_ERROR(Postprocess(instruction));
-  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+  return GetEvaluatedLiteralFor(instruction).Clone();
 }
 
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
-    HloInstruction* instruction) {
+template <>
+StatusOr<Literal> HloEvaluator::Evaluate<Literal>(
+    HloInstruction* instruction, absl::Span<const Literal> arg_literals) {
+  std::vector<const Literal*> arg_literal_ptrs;
+  for (const auto& literal : arg_literals) {
+    arg_literal_ptrs.push_back(&literal);
+  }
+  return Evaluate<const Literal*>(instruction, arg_literal_ptrs);
+}
+
+StatusOr<Literal> HloEvaluator::Evaluate(HloInstruction* instruction) {
   if (instruction->opcode() == HloOpcode::kParameter) {
     return tensorflow::errors::FailedPrecondition(
         "Cannot evaluate a parameter.");
@@ -275,21 +302,22 @@
   TF_RETURN_IF_ERROR(Preprocess(instruction));
   TF_RETURN_IF_ERROR(instruction->Visit(this));
   TF_RETURN_IF_ERROR(Postprocess(instruction));
-  return GetEvaluatedLiteralFor(instruction).CloneToUnique();
+  return GetEvaluatedLiteralFor(instruction).Clone();
 }
 
-std::unique_ptr<Literal> HloEvaluator::TryEvaluate(
-    HloInstruction* instruction) {
+bool HloEvaluator::TryEvaluate(HloInstruction* instruction, Literal* result) {
+  CHECK(result != nullptr);
   auto result_or = Evaluate(instruction);
   if (!result_or.ok()) {
     VLOG(1) << "TryEvaluate failed:" << result_or.status();
-    return nullptr;
+    return false;
   }
 
-  return result_or.ConsumeValueOrDie();
+  *result = result_or.ConsumeValueOrDie();
+  return true;
 }
 
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
+StatusOr<Literal> HloEvaluator::EvaluateWithSubstitutions(
     const HloInstruction* instruction,
     const std::unordered_map<const HloInstruction*, const Literal*>&
         substitutions) {
@@ -300,7 +328,7 @@
       owned_operands.push_back(operand->Clone());
     } else {
       owned_operands.push_back(
-          HloInstruction::CreateConstant(it->second->CloneToUnique()));
+          HloInstruction::CreateConstant(it->second->Clone()));
     }
   }
 
@@ -317,12 +345,12 @@
   return result;
 }
 
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseBinaryOp(
     HloOpcode opcode, const Literal& lhs, const Literal& rhs) {
   std::unique_ptr<HloInstruction> lhs_instr =
-      HloInstruction::CreateConstant(lhs.CloneToUnique());
+      HloInstruction::CreateConstant(lhs.Clone());
   std::unique_ptr<HloInstruction> rhs_instr =
-      HloInstruction::CreateConstant(rhs.CloneToUnique());
+      HloInstruction::CreateConstant(rhs.Clone());
 
   std::unique_ptr<HloInstruction> cloned_instruction =
       HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(),
@@ -332,10 +360,10 @@
   return result;
 }
 
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
+StatusOr<Literal> HloEvaluator::EvaluateElementwiseUnaryOp(
     HloOpcode opcode, const Literal& operand) {
   std::unique_ptr<HloInstruction> operand_instr =
-      HloInstruction::CreateConstant(operand.CloneToUnique());
+      HloInstruction::CreateConstant(operand.Clone());
 
   std::unique_ptr<HloInstruction> cloned_instruction =
       HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
@@ -344,13 +372,14 @@
   return result;
 }
 
-StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
-    const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+StatusOr<Literal> HloEvaluator::EvaluateDotOp(
+    const DotDimensionNumbers& dim_numbers,
+    const PrecisionConfig& precision_config, const Literal& lhs,
     const Literal& rhs) {
   std::unique_ptr<HloInstruction> lhs_instr =
-      HloInstruction::CreateConstant(lhs.CloneToUnique());
+      HloInstruction::CreateConstant(lhs.Clone());
   std::unique_ptr<HloInstruction> rhs_instr =
-      HloInstruction::CreateConstant(rhs.CloneToUnique());
+      HloInstruction::CreateConstant(rhs.Clone());
 
   TF_ASSIGN_OR_RETURN(
       Shape dot_shape,
@@ -358,7 +387,7 @@
 
   std::unique_ptr<HloInstruction> cloned_instruction =
       HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
-                                dim_numbers);
+                                dim_numbers, precision_config);
   return Evaluate(cloned_instruction.get());
 }
 
@@ -371,7 +400,7 @@
       << ", but input literal shape is: "
       << ShapeUtil::HumanString(input_literal->shape());
 
-  evaluated_[parameter] = input_literal->CloneToUnique();
+  evaluated_[parameter] = input_literal->Clone();
   return Status::OK();
 }
 
@@ -421,7 +450,7 @@
 
   for (auto operand : operands) {
     const Shape& operand_shape = operand->shape();
-    TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
+    TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
         GetEvaluatedLiteralFor(operand), source_indices, dest_indices,
         AsInt64Slice(operand_shape.dimensions())));
     dest_indices[concat_dim] +=
@@ -824,7 +853,7 @@
 // there is one) to `reshaped_start_indices`.
 static StatusOr<std::reference_wrapper<const Literal>> ReshapedGatherIndices(
     int64 index_vector_dim, const Literal& start_indices,
-    std::unique_ptr<Literal>* reshaped_start_indices) {
+    Literal* reshaped_start_indices) {
   if (start_indices.shape().dimensions_size() != index_vector_dim) {
     return std::cref(start_indices);
   }
@@ -834,16 +863,16 @@
   new_shape.push_back(1);
   TF_ASSIGN_OR_RETURN(*reshaped_start_indices,
                       start_indices.Reshape(new_shape));
-  return std::cref(**reshaped_start_indices);
+  return std::cref(*reshaped_start_indices);
 }
 
 Status HloEvaluator::HandleGather(HloInstruction* gather) {
-  std::unique_ptr<Literal> result = Literal::CreateFromShape(gather->shape());
+  Literal result = Literal::CreateFromShape(gather->shape());
   const Shape& shape = gather->shape();
   const GatherDimensionNumbers& dim_numbers =
       gather->gather_dimension_numbers();
   const Literal& operand = GetEvaluatedLiteralFor(gather->operand(0));
-  std::unique_ptr<Literal> reshaped_start_indices;
+  Literal reshaped_start_indices;
   TF_ASSIGN_OR_RETURN(
       const Literal& start_indices,
       ReshapedGatherIndices(dim_numbers.index_vector_dim(),
@@ -908,7 +937,7 @@
       DCHECK_LT(input_index[i], operand_shape.dimensions(i));
     }
     TF_RETURN_IF_ERROR(
-        result->CopyElementFrom(operand, input_index, output_index));
+        result.CopyElementFrom(operand, input_index, output_index));
     return true;
   };
 
@@ -940,8 +969,14 @@
   // Checks that operand's dimensions are the same as the broadcast's
   // dimensions along the dimensions to be broadcasted.
   for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
-    TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
-                 operand.shape().dimensions(i));
+    auto operand_dim_size = operand.shape().dimensions(i);
+    auto broadcast_dim_size =
+        broadcast->shape().dimensions(broadcast->dimensions(i));
+    TF_RET_CHECK(operand_dim_size == broadcast_dim_size) << absl::StreamFormat(
+        "Operand dimension %d is broadcast to output dimension %d, but the "
+        "sizes of these two dims do not match (%d vs %d): %s",
+        i, broadcast->dimensions(i), operand_dim_size, broadcast_dim_size,
+        broadcast->ToString());
   }
 
   TF_ASSIGN_OR_RETURN(
@@ -971,18 +1006,16 @@
 
   const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand);
 
-  evaluated_[get_tuple_element] = absl::make_unique<Literal>(
-      ShapeUtil::GetTupleElementShape(operand->shape(), index));
-  return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal,
-                                                 /*dest_shape_index=*/{},
-                                                 /*src_shape_index=*/{index});
+  evaluated_[get_tuple_element] =
+      Literal(ShapeUtil::GetTupleElementShape(operand->shape(), index));
+  return evaluated_[get_tuple_element].CopyFrom(operand_tuple_literal,
+                                                /*dest_shape_index=*/{},
+                                                /*src_shape_index=*/{index});
 }
 
 Status HloEvaluator::HandleCopy(HloInstruction* copy) {
   TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape()));
-
-  auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique();
-  evaluated_[copy] = std::move(result);
+  evaluated_[copy] = GetEvaluatedLiteralFor(copy->operand(0)).Clone();
   return Status::OK();
 }
 
@@ -998,7 +1031,7 @@
   }
 
   HloEvaluator embedded_evaluator;
-  std::unique_ptr<Literal> result =
+  Literal result =
       embedded_evaluator.Evaluate<const Literal*>(*computation, arg_literals)
           .ConsumeValueOrDie();
 
@@ -1030,7 +1063,7 @@
   }
 
   HloEvaluator embedded_evaluator;
-  std::unique_ptr<Literal> result =
+  Literal result =
       embedded_evaluator
           .Evaluate<const Literal*>(*readded_computation, arg_literals)
           .ConsumeValueOrDie();
@@ -1050,7 +1083,7 @@
   auto* false_computation = conditional->false_computation();
 
   HloEvaluator embedded_evaluator;
-  std::unique_ptr<Literal> result;
+  Literal result;
   if (pred.Get<bool>({})) {
     result = embedded_evaluator
                  .Evaluate<const Literal*>(*true_computation,
@@ -1075,9 +1108,9 @@
   // If predicate is of scalar type, no element-wise selection would be needed.
   if (ShapeUtil::IsScalar(pred.shape())) {
     if (pred.Get<bool>({})) {
-      evaluated_[select] = on_true.CloneToUnique();
+      evaluated_[select] = on_true.Clone();
     } else {
-      evaluated_[select] = on_false.CloneToUnique();
+      evaluated_[select] = on_false.Clone();
     }
     return Status::OK();
   }
@@ -1091,9 +1124,9 @@
   const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
 
   if (pred.Get<bool>({})) {
-    evaluated_[tuple_select] = on_true.CloneToUnique();
+    evaluated_[tuple_select] = on_true.Clone();
   } else {
-    evaluated_[tuple_select] = on_false.CloneToUnique();
+    evaluated_[tuple_select] = on_false.Clone();
   }
   return Status::OK();
 }
@@ -1102,7 +1135,7 @@
   HloComputation* cond_comp = while_hlo->while_condition();
   HloComputation* body_comp = while_hlo->while_body();
   // Initialize the loop carried valued with the input to the While instruction.
-  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).CloneToUnique();
+  auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone();
   bool keep_going = true;
   int64 iteration_count = 0;
   HloEvaluator cond_evaluator(max_loop_iterations_);
@@ -1112,13 +1145,13 @@
       return InvalidArgument("Loop %s exceeded loop iteration limit (%d).",
                              while_hlo->name(), max_loop_iterations_);
     }
-    TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>(
-                                           *cond_comp, {lcv.get()}));
-    keep_going = cond_val->GetFirstElement<bool>();
+    TF_ASSIGN_OR_RETURN(auto cond_val,
+                        cond_evaluator.Evaluate<Literal*>(*cond_comp, {&lcv}));
+    keep_going = cond_val.GetFirstElement<bool>();
     if (keep_going) {
       TF_ASSIGN_OR_RETURN(auto body_val, loop_body_evaluator.Evaluate<Literal*>(
-                                             *body_comp, {lcv.get()}));
-      VLOG(3) << "Loop iteration result: " << body_val->ToString();
+                                             *body_comp, {&lcv}));
+      VLOG(3) << "Loop iteration result: " << body_val.ToString();
       lcv = std::move(body_val);
       cond_evaluator.ResetVisitStates();
       loop_body_evaluator.ResetVisitStates();
@@ -1133,9 +1166,9 @@
 // hoops to make this work.
 namespace {
 template <typename KeyType, typename ValueType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal(
-    HloInstruction* sort, const Literal& keys_literal,
-    const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
+                                       const Literal& keys_literal,
+                                       const Literal& values_literal) {
   auto rank = ShapeUtil::Rank(keys_literal.shape());
   TF_RET_CHECK(
       ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
@@ -1173,57 +1206,55 @@
       result_keys.push_back(key_value.first);
       result_values.push_back(key_value.second);
     }
-    auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape());
-    result_keys_literal->PopulateR1(absl::Span<const KeyType>(result_keys));
-    auto result_values_literal =
-        absl::make_unique<Literal>(values_literal.shape());
-    result_values_literal->PopulateR1(
+    Literal result_keys_literal(keys_literal.shape());
+    result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
+    Literal result_values_literal(values_literal.shape());
+    result_values_literal.PopulateR1(
         absl::Span<const ValueType>(result_values));
     return std::make_pair(std::move(result_keys_literal),
                           std::move(result_values_literal));
   };
 
-  std::unique_ptr<Literal> result_tuple;
+  Literal result_tuple;
   if (rank == 1) {
     auto result_pair = sort_r1(keys_literal, values_literal);
-    result_tuple = LiteralUtil::MakeTuple(
-        {result_pair.first.get(), result_pair.second.get()});
+    result_tuple =
+        LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
   } else {
     // For R2 sort, the desired semantics are to sort each matrix row
     // independently.
-    auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape());
-    auto values_result_literal =
-        absl::make_unique<Literal>(values_literal.shape());
+    Literal keys_result_literal(keys_literal.shape());
+    Literal values_result_literal(values_literal.shape());
     int64 r1_length = keys_literal.shape().dimensions(1);
     for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
       TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
                           keys_literal.Slice({row, 0}, {row + 1, r1_length})
-                              ->Reshape({r1_length}));
+                              .Reshape({r1_length}));
       TF_ASSIGN_OR_RETURN(auto values_r1_slice,
                           values_literal.Slice({row, 0}, {row + 1, r1_length})
-                              ->Reshape({r1_length}));
-      auto r1_result_pair = sort_r1(*keys_r1_slice, *values_r1_slice);
+                              .Reshape({r1_length}));
+      auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
       TF_ASSIGN_OR_RETURN(auto sorted_keys,
-                          r1_result_pair.first->Reshape({1, r1_length}));
+                          r1_result_pair.first.Reshape({1, r1_length}));
       TF_ASSIGN_OR_RETURN(auto sorted_values,
-                          r1_result_pair.second->Reshape({1, r1_length}));
-      TF_RETURN_IF_ERROR(keys_result_literal->CopySliceFrom(
-          *sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
-      TF_RETURN_IF_ERROR(values_result_literal->CopySliceFrom(
-          *sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
+                          r1_result_pair.second.Reshape({1, r1_length}));
+      TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+          sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
+      TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+          sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
     }
-    result_tuple = LiteralUtil::MakeTuple(
-        {keys_result_literal.get(), values_result_literal.get()});
+    result_tuple =
+        LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
   }
 
-  VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+  VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
   return std::move(result_tuple);
 }
 
 template <typename KeyType>
-StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
-    HloInstruction* sort, const Literal& keys_literal,
-    const Literal& values_literal) {
+StatusOr<Literal> EvaluateSortCurried(HloInstruction* sort,
+                                      const Literal& keys_literal,
+                                      const Literal& values_literal) {
   switch (sort->operand(1)->shape().element_type()) {
     case F32:
       return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
@@ -1242,9 +1273,9 @@
   }
 }
 
-StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
-                                                const Literal& keys_literal,
-                                                const Literal& values_literal) {
+StatusOr<Literal> EvaluateSort(HloInstruction* sort,
+                               const Literal& keys_literal,
+                               const Literal& values_literal) {
   switch (sort->operand(0)->shape().element_type()) {
     case F32:
       return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
@@ -1308,33 +1339,25 @@
 Status HloEvaluator::Postprocess(HloInstruction* hlo) {
   VLOG(2) << "Finished visiting " << hlo->ToString()
           << "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+  // Out of convenience the literal may have been produced with a different
+  // layout. Relayout as indicated by the HLO instruction.
+  if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+                                        hlo->shape())) {
+    evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+  }
   return Status::OK();
 }
 
 // Explicit instantiation of templatized Evaluate* methods.
 //
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
     const HloModule& module, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
-    const HloModule& module,
-    absl::Span<const std::unique_ptr<Literal>> arg_literals);
 
-template StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate<
-    const Literal*>(const HloComputation& computation,
-                    absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
     const HloComputation& computation,
-    absl::Span<const std::unique_ptr<Literal>> arg_literals);
+    absl::Span<const Literal* const> arg_literals);
 
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<const Literal*>(
+template StatusOr<Literal> HloEvaluator::Evaluate<const Literal*>(
     HloInstruction* instruction, absl::Span<const Literal* const> arg_literals);
-template StatusOr<std::unique_ptr<Literal>>
-HloEvaluator::Evaluate<std::unique_ptr<Literal>>(
-    HloInstruction* instruction,
-    absl::Span<const std::unique_ptr<Literal>> arg_literals);
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index c2d49e5..21e676d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -47,11 +47,11 @@
   // Precondition: The indices of arg_literals correspond to the parameter
   // numbers of the HLO parameters in the computation. See comment below for an
   // example.
-  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+  // `LiteralPtr` accepts either Literal or const Literal*
   // type.
   template <typename LiteralPtr>
-  StatusOr<std::unique_ptr<Literal>> Evaluate(
-      const HloModule& module, absl::Span<const LiteralPtr> arg_literals);
+  StatusOr<Literal> Evaluate(const HloModule& module,
+                             absl::Span<const LiteralPtr> arg_literals);
 
   // Evaluates an HLO computation and an array of pointers to literals.
   // Returns the evaluated result as a literal if successful.
@@ -69,12 +69,11 @@
   // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
   // 1 in this computation. The input literals array will then have its first
   // literal map to Parameter0 and the second map to Parameter1.
-  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+  // `LiteralPtr` accepts either Literal or const Literal*
   // type.
   template <typename LiteralPtr>
-  StatusOr<std::unique_ptr<Literal>> Evaluate(
-      const HloComputation& computation,
-      absl::Span<const LiteralPtr> arg_literals);
+  StatusOr<Literal> Evaluate(const HloComputation& computation,
+                             absl::Span<const LiteralPtr> arg_literals);
 
   // Evaluates a single HLO instruction and an array of pointers to literals.
   // Return the evaluated result as literal if successful.
@@ -82,41 +81,43 @@
   // 1. argument literals correspond to the input instruction's parameters in
   // their post-ordering.
   // 2. the instruction's operands must be of either Parameter or Constant type.
-  // `LiteralPtr` accepts either std::unique_ptr<Literal> or const Literal*
+  // `LiteralPtr` accepts either Literal or const Literal*
   // type.
   template <typename LiteralPtr>
-  StatusOr<std::unique_ptr<Literal>> Evaluate(
-      HloInstruction* instruction, absl::Span<const LiteralPtr> arg_literals);
+  StatusOr<Literal> Evaluate(HloInstruction* instruction,
+                             absl::Span<const LiteralPtr> arg_literals);
 
   // Evaluates a single HLO instruction with constant operands.
   // Returns the evaluated result as literal if successful.
   // Precondition:
   // 1. all operands of the input instruction are constants.
   // 2. the instruction is not a Parameter operation.
-  StatusOr<std::unique_ptr<Literal>> Evaluate(HloInstruction* instruction);
+  StatusOr<Literal> Evaluate(HloInstruction* instruction);
 
-  // Same as Evaluate, except returning nullptr on error.
-  std::unique_ptr<Literal> TryEvaluate(HloInstruction* instruction);
+  // Same as Evaluate, except returning false on error and accepts an output
+  // pointer.
+  bool TryEvaluate(HloInstruction* instruction, Literal* result);
 
   // Evaluates a single HLO instruction, substituting the given literals for
   // some of the instruction's operands.
   //
   // For example, given instruction = op(A, B, C) and the map
   // {A = x, C = y}, this evaluates op(x, B, y).
-  StatusOr<std::unique_ptr<Literal>> EvaluateWithSubstitutions(
+  StatusOr<Literal> EvaluateWithSubstitutions(
       const HloInstruction* instruction,
       const std::unordered_map<const HloInstruction*, const Literal*>&
           substitutions);
 
-  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseBinaryOp(
-      HloOpcode opcode, const Literal& lhs, const Literal& rhs);
+  StatusOr<Literal> EvaluateElementwiseBinaryOp(HloOpcode opcode,
+                                                const Literal& lhs,
+                                                const Literal& rhs);
 
-  StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
-      HloOpcode opcode, const Literal& operand);
+  StatusOr<Literal> EvaluateElementwiseUnaryOp(HloOpcode opcode,
+                                               const Literal& operand);
 
-  StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
-      const DotDimensionNumbers& dim_numbers, const Literal& lhs,
-      const Literal& rhs);
+  StatusOr<Literal> EvaluateDotOp(const DotDimensionNumbers& dim_numbers,
+                                  const PrecisionConfig& precision_config,
+                                  const Literal& lhs, const Literal& rhs);
 
  protected:
   // Make HloEvaluatorTypedVisitor a friend because it is logically part of this
@@ -196,7 +197,7 @@
     auto it = evaluated_.find(hlo);
     CHECK(it != evaluated_.end())
         << "could not find evaluated value for: " << hlo->ToString();
-    return *(it->second);
+    return it->second;
   }
 
   // Tracks the HLO instruction and its evaluated literal result.
@@ -204,12 +205,13 @@
   // that are no longer a parent for any other subsequent instruction in
   // post-orderring.
   // Must be cleared for each evaluation.
-  tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<Literal>>
-      evaluated_;
+  // Storing Literal in place requires the container to have pointer stability
+  // so we cannot use FlatMap any more.
+  std::unordered_map<const HloInstruction*, Literal> evaluated_;
 
  private:
   template <typename ReturnT, typename NativeT>
-  static StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
+  static StatusOr<Literal> ElementWiseUnaryOpImpl(
       HloInstruction* instruction,
       const std::function<ReturnT(NativeT)>& unary_op,
       const Literal& operand_literal) {
@@ -226,9 +228,9 @@
           ShapeUtil::HumanString(operand->shape()));
     }
 
-    auto result = absl::make_unique<Literal>(shape);
+    Literal result(shape);
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
           return unary_op(operand_literal.Get<NativeT>(multi_index));
         }));
     return std::move(result);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 7e490d7..01e8856 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,15 +52,11 @@
 class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
                          public HloVerifiedTestBase {
  protected:
-  HloEvaluatorTest()
-      : HloVerifiedTestBase(/*layout_sensitive=*/false,
-                            /*allow_mixed_precision=*/false),
-        use_bfloat16_(GetParam()) {
+  HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
     evaluator_ = absl::make_unique<HloEvaluator>();
   }
 
-  std::unique_ptr<Literal> Evaluate(
-      absl::Span<const Literal* const> arg_literals = {}) {
+  Literal Evaluate(absl::Span<const Literal* const> arg_literals = {}) {
     if (use_bfloat16_) {
       // In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
       auto type_converter = HloElementTypeConverter(F32, BF16);
@@ -72,39 +68,37 @@
 
   std::unique_ptr<HloEvaluator> evaluator_;
 
-  void TestUnaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
-                   std::unique_ptr<Literal> input, float aabs = 0) {
+  void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
+                   float aabs = 0) {
     HloComputation::Builder b(TestName());
     auto c1 =
         b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
-    b.AddInstruction(
-        HloInstruction::CreateUnary(expected->shape(), opcode, c1));
+    b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
     module().AddEntryComputation(b.Build());
 
-    std::unique_ptr<Literal> result = Evaluate();
+    Literal result = Evaluate();
 
-    auto element_type = expected->shape().element_type();
+    auto element_type = expected.shape().element_type();
     if (element_type == F32 || element_type == F64) {
       ErrorSpec error(aabs);
-      EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
+      EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
     } else {
-      EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+      EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
     }
   }
 
-  void TestBinaryOp(HloOpcode opcode, std::unique_ptr<Literal> expected,
-                    std::unique_ptr<Literal> lhs,
-                    std::unique_ptr<Literal> rhs) {
+  void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
+                    Literal rhs) {
     HloComputation::Builder b(TestName());
     auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
     auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
     b.AddInstruction(
-        HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2));
+        HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
     module().AddEntryComputation(b.Build());
 
-    std::unique_ptr<Literal> result = Evaluate();
+    Literal result = Evaluate();
 
-    EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
   }
 
   bool use_bfloat16_;
@@ -120,7 +114,7 @@
   auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
   auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
 
-  Shape shape = low->shape();
+  Shape shape = low.shape();
   HloComputation::Builder b(TestName());
   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -129,11 +123,11 @@
       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -141,7 +135,7 @@
   auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
   auto high = LiteralUtil::CreateR0<float>(1.f);
 
-  Shape shape = value->shape();
+  Shape shape = value.shape();
   HloComputation::Builder b(TestName());
   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
@@ -150,11 +144,11 @@
       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -164,7 +158,7 @@
   auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
   auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
 
-  Shape shape = on_true->shape();
+  Shape shape = on_true.shape();
   HloComputation::Builder b(TestName());
   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
   auto c2 =
@@ -175,11 +169,11 @@
       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate({});
+  Literal result = Evaluate({});
 
   auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -298,7 +292,7 @@
   auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
   auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
   auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
-  std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
+  std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
 
   Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
 
@@ -316,11 +310,11 @@
                                                 lhs_instruction, param_rhs2));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate(args);
+  Literal result = Evaluate(args);
 
   auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 // Verifies Reshape operation is correctly evaluated.
@@ -330,7 +324,7 @@
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
                           LiteralUtil::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
-  auto literal_clone = literal->CloneToUnique();
+  auto literal_clone = literal.Clone();
   HloInstruction* literal_instruction =
       b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
 
@@ -340,14 +334,13 @@
       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate({});
+  Literal result = Evaluate({});
 
   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
-  result->EachCell<NativeT>(
-      [&](absl::Span<const int64> indices, NativeT value) {
-        std::vector<int64> rindexes = Permute(permutation, indices);
-        EXPECT_NEAR(value, literal_clone->Get<NativeT>(rindexes), 0.031250);
-      });
+  result.EachCell<NativeT>([&](absl::Span<const int64> indices, NativeT value) {
+    std::vector<int64> rindexes = Permute(permutation, indices);
+    EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
+  });
 }
 
 // Verifies Broadcast operation is correctly evaluated.
@@ -359,12 +352,12 @@
   HloInstruction* literal_instruction = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
   b.AddInstruction(HloInstruction::CreateBroadcast(
-      output_literal->shape(), literal_instruction, {1, 2}));
+      output_literal.shape(), literal_instruction, {1, 2}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate({});
+  Literal result = Evaluate({});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
 }
 
 TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -377,13 +370,13 @@
       HloInstruction::CreateConstant(std::move(input_literal)));
   // Broadcast dimension should be empty in the case of scalars.
   b.AddInstruction(HloInstruction::CreateBroadcast(
-      output_literal->shape(), literal_instruction,
+      output_literal.shape(), literal_instruction,
       /*broadcast_dimensions=*/{}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate({});
+  Literal result = Evaluate({});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
 }
 
 TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -401,11 +394,11 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<int64>(
       {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -423,10 +416,10 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR1<int64>({100, 200});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -435,17 +428,17 @@
   auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
   auto expected =
       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
-  ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
-                                               expected->shape()));
+  ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+                                               expected.shape()));
 
   HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
 TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -455,17 +448,17 @@
       {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
   auto expected = LiteralUtil::CreateR2WithLayout<float>(
       {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
-  ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
-                                                expected->shape()));
+  ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal.shape(),
+                                                expected.shape()));
 
   HloInstruction* constant = b.AddInstruction(
       HloInstruction::CreateConstant(std::move(input_literal)));
-  b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant));
+  b.AddInstruction(HloInstruction::CreateConvert(expected.shape(), constant));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
 PaddingConfig CreatePaddingConfig(
@@ -498,12 +491,12 @@
       shape, operand_instruction, padding_value_instruction, padding_config));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<int32>(
       {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -525,7 +518,7 @@
       shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
   expected_array->Fill(kPadValue);
@@ -538,7 +531,7 @@
 
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -569,7 +562,7 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
   auto expected_array = absl::make_unique<Array2D<float>>(1, 5);
@@ -580,7 +573,7 @@
   (*expected_array)(0, 4) = 2.718f;
   auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
 }
 
 TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -614,12 +607,12 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected_array = absl::make_unique<Array2D<float>>(0, 9);
   auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -649,10 +642,11 @@
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
-                                             rhs_instruction, dot_dnums));
+                                             rhs_instruction, dot_dnums,
+                                             DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   // clang-format off
   auto expected_array = Array2D<float>({
@@ -664,7 +658,7 @@
   // clang-format on
   auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -694,14 +688,15 @@
   dot_dnums.add_lhs_contracting_dimensions(0);
   dot_dnums.add_rhs_contracting_dimensions(0);
   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
-                                             rhs_instruction, dot_dnums));
+                                             rhs_instruction, dot_dnums,
+                                             DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -737,10 +732,11 @@
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
   b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
-                                             rhs_instruction, dot_dnums));
+                                             rhs_instruction, dot_dnums,
+                                             DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected_array = Array2D<float>({
       {22.f, 28.f},
@@ -750,7 +746,7 @@
   });
   auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -788,17 +784,18 @@
   dnums.set_kernel_input_feature_dimension(1);
   dnums.add_kernel_spatial_dimensions(2);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
   auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -842,12 +839,13 @@
   ConvolutionDimensionNumbers dnums =
       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   Array4D<float> expected_array(1, 1, 4, 4);
   // clang-format off
@@ -860,7 +858,7 @@
   // clang-format on
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -925,12 +923,13 @@
   dnums.add_kernel_spatial_dimensions(3);
   dnums.add_kernel_spatial_dimensions(1);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   // clang-format off
   // Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -940,7 +939,7 @@
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
       use_bfloat16_ ? expected_array_bf16 : expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1002,12 +1001,13 @@
   dnums.add_kernel_spatial_dimensions(3);
   dnums.add_kernel_spatial_dimensions(1);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   // clang-format off
   // Result dimensions: [feature=1, height=1, batch=1, width=2]
@@ -1017,7 +1017,7 @@
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
       use_bfloat16_ ? expected_array_bf16 : expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1061,12 +1061,13 @@
   ConvolutionDimensionNumbers dnums =
       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   Array4D<float> expected_array(1, 1, 7, 7);
   expected_array.FillWithYX(Array2D<float>({
@@ -1080,7 +1081,7 @@
   }));
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1124,12 +1125,13 @@
   ConvolutionDimensionNumbers dnums =
       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   Array4D<float> expected_array(1, 1, 8, 8);
   expected_array.FillWithYX(Array2D<float>({
@@ -1144,7 +1146,7 @@
   }));
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest,
@@ -1195,12 +1197,13 @@
   ConvolutionDimensionNumbers dnums =
       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
 
-  const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
   b.AddInstruction(HloInstruction::CreateConvolve(
-      shape, lhs_instruction, rhs_instruction, window, dnums));
+      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
+      window, dnums, DefaultPrecisionConfig(2)));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   Array4D<float> expected_array(1, 1, 9, 3);
   expected_array.FillWithYX(Array2D<float>({
@@ -1216,7 +1219,68 @@
   }));
   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
+TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) {
+  HloComputation::Builder b(TestName());
+  std::vector<int64> input_dims = {1, 2, 2, 4};
+  std::vector<int64> filter_dims = {2, 2, 2, 8};
+  Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
+  Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
+  // Tensorflow dimension numbers for 2D convolution.
+  ConvolutionDimensionNumbers dnums;
+  dnums.set_input_batch_dimension(0);
+  dnums.set_output_batch_dimension(0);
+  dnums.add_input_spatial_dimensions(1);
+  dnums.add_output_spatial_dimensions(1);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.set_input_feature_dimension(3);
+  dnums.set_output_feature_dimension(3);
+  dnums.add_kernel_spatial_dimensions(0);
+  dnums.add_kernel_spatial_dimensions(1);
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(3);
+
+  Window window;
+  WindowDimension dim;
+  dim.set_size(2);
+  dim.set_stride(1);
+  dim.set_padding_low(0);
+  dim.set_padding_high(0);
+  dim.set_window_dilation(1);
+  dim.set_base_dilation(1);
+  *window.add_dimensions() = dim;
+  *window.add_dimensions() = dim;
+
+  std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+  std::iota(input_elems.begin(), input_elems.end(), -7);
+  auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+  auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
+  HloInstruction* lhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
+
+  std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+  std::iota(filter_elems.begin(), filter_elems.end(), -31);
+  auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+  auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
+  HloInstruction* rhs_instruction =
+      b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
+
+  Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
+  b.AddInstruction(HloInstruction::CreateConvolve(
+      shape, lhs_instruction, rhs_instruction,
+      /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2)));
+  module().AddEntryComputation(b.Build());
+
+  Literal result = Evaluate();
+
+  Array4D<float> expected_array(1, 1, 1, 8);
+  expected_array.FillWithYX(
+      Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
+  auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1249,9 +1313,8 @@
   module().AddEntryComputation(b.Build());
 
   HloEvaluator hlo_eval;
-  std::unique_ptr<Literal> result =
-      hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
-  LiteralTestUtil::ExpectR0Equal<float>(kNumElements, *result);
+  Literal result = hlo_eval.Evaluate(reduce_instruction).ConsumeValueOrDie();
+  LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
 }
 
 // Reducing many numbers should be fast because it doesn't create
@@ -1328,11 +1391,11 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR1<float>({6, 18});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1380,10 +1443,10 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1437,10 +1500,10 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1448,7 +1511,7 @@
 
   // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
   std::vector<int64> input_dims(6, 4);
-  std::unique_ptr<Literal> arg_literal =
+  Literal arg_literal =
       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
 
   HloInstruction* arg_instruction =
@@ -1498,12 +1561,12 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
-  std::unique_ptr<Literal> result_literal =
+  Literal result_literal =
       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
 }
 
 TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1530,14 +1593,14 @@
                                                /*strides=*/{2, 3}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({
       {3},
       {19},
   });
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1564,14 +1627,14 @@
                                                       start_indices, {2, 3}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({
       {2, 3, 4},
       {6, 7, 8},
   });
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 // Verifies that the HloEvaluator's implementation goes along with existing
@@ -1600,14 +1663,14 @@
                                                       start_indices, {2, 3}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<float>({
       {2, 3, 4},
       {6, 7, 8},
   });
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1637,14 +1700,14 @@
       shape, operand, update, start_indices));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<double>({
       {1, -2, -3},
       {5, -6, -7},
   });
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1673,14 +1736,14 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto expected = LiteralUtil::CreateR2<double>({
       {1, 2, 3},
       {5, 6, 7},
   });
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1712,16 +1775,14 @@
 
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   auto result_inner_literal =
       LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
-  auto expected = LiteralUtil::MakeTuple({
-      result_inner_literal.get(),
-      result_inner_literal.get(),
-  });
+  auto expected =
+      LiteralUtil::MakeTuple({&result_inner_literal, &result_inner_literal});
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, Reverse) {
@@ -1752,7 +1813,7 @@
   b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1}));
   module().AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = Evaluate();
+  Literal result = Evaluate();
 
   // clang-format off
   auto expected = LiteralUtil::CreateR4FromArray4D<float>({
@@ -1774,7 +1835,7 @@
   });
   // clang-format on
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1790,12 +1851,13 @@
 
   // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
   HloEvaluator evaluator;
+  Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
+  Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
   auto result = evaluator.EvaluateWithSubstitutions(
-      add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
-            {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+      add, {{param0, &param0_literal}, {square, &square_literal}});
   TF_ASSERT_OK(result.status());
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+      LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
 }
 
 // Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1815,11 +1877,12 @@
 
   // Evaluate add with square = {10, 20, 30, 40}.
   HloEvaluator evaluator;
-  auto result = evaluator.EvaluateWithSubstitutions(
-      add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
+  Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
+  auto result =
+      evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}});
   TF_ASSERT_OK(result.status());
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+      LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result.ValueOrDie()));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1838,12 +1901,12 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
-      *Evaluate({operand.get(), start_indices.get()})));
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+      Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1862,12 +1925,12 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
-      *Evaluate({operand.get(), start_indices.get()})));
+      LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+      Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1886,14 +1949,13 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR3<int32>(
+      LiteralUtil::CreateR3<int32>(
           {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
-      *Evaluate({operand.get(), start_indices.get()})));
+      Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1912,15 +1974,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+                             Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest,
@@ -1940,15 +2001,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+                             Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1967,12 +2027,11 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+  Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{5}}),
+                                     Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1991,13 +2050,12 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
+                             Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2016,11 +2074,10 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
-  EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+  Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{}, {}}),
+                                     Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2040,12 +2097,12 @@
 )";
   ParseAndVerifyModule(hlo_text);
 
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
-  std::unique_ptr<Literal> start_indices =
+  Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+  Literal start_indices =
       LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
-                             *Evaluate({operand.get(), start_indices.get()})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
+                             Evaluate({&operand, &start_indices})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
@@ -2070,15 +2127,13 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
@@ -2103,15 +2158,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates =
       LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
@@ -2137,15 +2191,13 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
@@ -2171,15 +2223,13 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{10, 40, 90}, {4, 5, 6}, {490, 640, 810}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_F32) {
@@ -2205,17 +2255,15 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+  Literal operand = LiteralUtil::CreateR2<float>(
       {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({2, 1});
-  std::unique_ptr<Literal> updates =
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+  Literal updates =
       LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>(
+      LiteralUtil::CreateR2<float>(
           {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()}),
-      ErrorSpec{0.1, 0.01}));
+      Evaluate({&operand, &scatter_indices, &updates}), ErrorSpec{0.1, 0.01}));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
@@ -2241,15 +2289,13 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({1, 1});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
@@ -2275,15 +2321,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+  Literal updates = LiteralUtil::CreateR3<int32>(
       {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      LiteralUtil::CreateR2<int32>({{11, 7, 38}, {44, 10, 71}, {77, 13, 104}}),
+      Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
@@ -2308,21 +2353,18 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
-  std::unique_ptr<Literal> expected =
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+  Literal expected =
       LiteralUtil::CreateR3<int32>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                     {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *expected,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      expected, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest,
@@ -2348,21 +2390,18 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
-  std::unique_ptr<Literal> expected =
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+  Literal expected =
       LiteralUtil::CreateR3<int32>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *expected,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      expected, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
@@ -2387,16 +2426,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({1, 1});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
-  std::unique_ptr<Literal> expected =
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+  Literal expected =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *expected,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      expected, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
@@ -2421,17 +2458,14 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
-  std::unique_ptr<Literal> expected =
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+  Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+  Literal expected =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *expected,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      expected, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
@@ -2456,13 +2490,11 @@
 }
 )";
   ParseAndVerifyModule(hlo_text);
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+  Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *operand,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      operand, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
@@ -2489,16 +2521,13 @@
 )";
   ParseAndVerifyModule(hlo_text);
 
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
-  std::unique_ptr<Literal> scatter_indices =
+  Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+  Literal scatter_indices =
       LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::CreateR1<int32>({10, 61, 32});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+  Literal expected = LiteralUtil::CreateR1<int32>({10, 61, 32});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *expected,
-      *Evaluate({operand.get(), scatter_indices.get(), updates.get()})));
+      expected, Evaluate({&operand, &scatter_indices, &updates})));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -2535,11 +2564,29 @@
 )";
   ParseAndVerifyModule(hlo_text);
 
-  std::unique_ptr<Literal> arg = LiteralUtil::CreateR1<bfloat16>(
+  Literal arg = LiteralUtil::CreateR1<bfloat16>(
       {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *Evaluate({arg.get()})));
+  Literal expected = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
+}
+
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+  // Regression test for b/114735354.
+  const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+  arg = f32[2,2,2]{0,1,2} parameter(0)
+  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+  ParseAndVerifyModule(hlo_text);
+
+  Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+      LayoutUtil::MakeLayout({0, 1, 2}));
+  Literal actual = Evaluate({&arg});
+  EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
 }
 
 INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index cb27e13..8fb17a0 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -246,32 +246,21 @@
   Status HandleConvert(HloInstruction* convert) override {
     const HloInstruction* operand = convert->operand(0);
     TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+    TF_ASSIGN_OR_RETURN(Literal result,
                         parent_->GetEvaluatedLiteralFor(operand).Convert(
                             convert->shape().element_type()));
-
-    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
-      parent_->evaluated_[convert] = std::move(result);
-    } else {
-      parent_->evaluated_[convert] =
-          result->Relayout(convert->shape().layout());
-    }
+    parent_->evaluated_[convert] = std::move(result);
     return Status::OK();
   }
 
   Status HandleBitcastConvert(HloInstruction* convert) override {
     const HloInstruction* operand = convert->operand(0);
     TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+    TF_ASSIGN_OR_RETURN(Literal result,
                         parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
                             convert->shape().element_type()));
 
-    if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
-      parent_->evaluated_[convert] = std::move(result);
-    } else {
-      parent_->evaluated_[convert] =
-          result->Relayout(convert->shape().layout());
-    }
+    parent_->evaluated_[convert] = std::move(result);
     return Status::OK();
   }
 
@@ -978,10 +967,10 @@
         << ShapeUtil::HumanString(inferred_return_shape);
 
     const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
-    auto result = absl::make_unique<Literal>(result_shape);
+    Literal result(result_shape);
 
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> out_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> out_index) {
           std::vector<int64> from_index(out_index.begin(), out_index.end());
           for (const int64 dim : reverse_dimensions) {
             from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim];
@@ -1021,9 +1010,10 @@
     CHECK_EQ(num_spatial_dims + 2, lhs_rank);
     CHECK_EQ(num_spatial_dims + 2, rhs_rank);
 
-    TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
-                        ShapeInference::InferConvolveShape(lhs_shape, rhs_shape,
-                                                           window, dnums));
+    TF_ASSIGN_OR_RETURN(
+        auto inferred_return_shape,
+        ShapeInference::InferConvolveShape(
+            lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums));
     CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
         << "return shape set to: " << ShapeUtil::HumanString(result_shape)
         << " but is inferred to be: "
@@ -1046,9 +1036,12 @@
     auto lhs_literal_data = lhs_literal.data<ReturnT>();
     auto rhs_literal_data = rhs_literal.data<ReturnT>();
 
+    int64 feature_group_count = conv->feature_group_count();
+
     auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
                  &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
-                 rhs_literal_data](absl::Span<const int64> out_index) {
+                 rhs_literal_data,
+                 feature_group_count](const absl::Span<const int64> out_index) {
       // Dimension number applicable for input (lhs).
       const int64 input_batch_dim = dnums.input_batch_dimension();
       const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1059,7 +1052,22 @@
       const int64 output_batch_dim = dnums.output_batch_dimension();
       const int64 output_z_dim = dnums.output_feature_dimension();
 
-      const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+      const int64 input_z_size =
+          ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+      // The size of an input feature group.
+      const int64 input_feature_group_size = input_z_size / feature_group_count;
+
+      const int64 output_z_size =
+          ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
+      // The output feature dimension is a concatenation of convolution results
+      // from the different groups.
+      const int64 output_feature_group_size =
+          output_z_size / feature_group_count;
+
+      // Calculate the group index to which the current output index
+      // belongs.
+      const int64 feature_group_index =
+          out_index[output_z_dim] / output_feature_group_size;
 
       ElementwiseT result_val = static_cast<ElementwiseT>(0);
       DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1067,7 +1075,10 @@
 
       // Convolve input feature with kernel.
       do {
-        for (int64 iz = 0; iz < z_size; ++iz) {
+        for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
+          const int64 iz =
+              feature_group_index * input_feature_group_size + rhs_iz;
+
           int64 lhs_linear_index = 0;
           lhs_linear_index += out_index[output_batch_dim] *
                               lhs_dim_multipliers[input_batch_dim];
@@ -1076,7 +1087,7 @@
           int64 rhs_linear_index = 0;
           rhs_linear_index += out_index[output_z_dim] *
                               rhs_dim_multipliers[kernel_output_z_dim];
-          rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
+          rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim];
 
           // Find corresponding spatial dimension index for input (lhs).
           for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
@@ -1135,8 +1146,8 @@
       return static_cast<ReturnT>(result_val);
     };
 
-    auto result = absl::make_unique<Literal>(result_shape);
-    TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
+    Literal result(result_shape);
+    TF_RETURN_IF_ERROR(result.PopulateParallel<ReturnT>(func));
 
     parent_->evaluated_[conv] = std::move(result);
     return Status::OK();
@@ -1209,9 +1220,9 @@
       }
     }
 
-    auto result = absl::make_unique<Literal>(dot->shape());
+    Literal result(dot->shape());
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> result_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> result_index) {
           ElementwiseT result_val = static_cast<ElementwiseT>(0);
 
           for (int64 i = 0; i < result_index.size(); i++) {
@@ -1258,8 +1269,8 @@
     // Create new HLO of padded shape with padding value.
     ReturnT scalar =
         parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({});
-    auto result = absl::make_unique<Literal>(pad->shape());
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+    Literal result(pad->shape());
+    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
         [&scalar](absl::Span<const int64> multi_index) { return scalar; }));
 
     const Literal& evaluated_operand =
@@ -1267,7 +1278,7 @@
 
     std::vector<int64> input_index(ShapeUtil::Rank(evaluated_operand.shape()),
                                    0);
-    std::vector<int64> target_index(ShapeUtil::Rank(result->shape()), 0);
+    std::vector<int64> target_index(ShapeUtil::Rank(result.shape()), 0);
 
     // Loop through each element of the operand, assign them to the
     // corresponding index of the resulting padded literal.
@@ -1289,8 +1300,8 @@
           return true;
         }
       }
-      result->Set<ReturnT>(target_index,
-                           evaluated_operand.Get<ReturnT>(input_index));
+      result.Set<ReturnT>(target_index,
+                          evaluated_operand.Get<ReturnT>(input_index));
       return true;
     };
 
@@ -1417,16 +1428,16 @@
   }
 
   template <typename NativeT>
-  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+  StatusOr<Literal> MapImpl(HloInstruction* map) {
     auto operands = map->operands();
     HloComputation* computation = map->to_apply();
 
-    auto result = absl::make_unique<Literal>(map->shape());
+    Literal result(map->shape());
 
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
-          std::vector<std::unique_ptr<Literal>> arg_literals;
+        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+          std::vector<Literal> arg_literals;
           arg_literals.reserve(operands.size());
 
           // Construct scalar literal parameters to be passed to the map
@@ -1441,16 +1452,14 @@
             arg_literals.push_back(std::move(curr_val_literal));
           }
 
-          std::unique_ptr<Literal> computed_result =
-              embedded_evaluator
-                  .Evaluate<std::unique_ptr<Literal>>(*computation,
-                                                      arg_literals)
+          Literal computed_result =
+              embedded_evaluator.Evaluate<Literal>(*computation, arg_literals)
                   .ConsumeValueOrDie();
           // Clear visit states so that the we can use the evaluate again on
           // the same computation.
           embedded_evaluator.ResetVisitStates();
 
-          return computed_result->Get<ReturnT>({});
+          return computed_result.Get<ReturnT>({});
         }));
     return std::move(result);
   }
@@ -1535,9 +1544,9 @@
                 [](const ReturnT& a, const ReturnT& b) {
                   return SafeLess<ReturnT>(a, b);
                 });
-      auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
-      result_literal->PopulateR1(absl::Span<const ReturnT>(result_data));
-      VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+      Literal result_literal(keys_literal.shape());
+      result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
+      VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
       return result_literal;
     };
 
@@ -1546,16 +1555,16 @@
     } else {
       // For R2 sort, the desired semantics are to sort each matrix row
       // independently.
-      auto result_literal = absl::make_unique<Literal>(keys_literal.shape());
+      Literal result_literal(keys_literal.shape());
       int64 r1_length = keys->shape().dimensions(1);
       for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
         TF_ASSIGN_OR_RETURN(auto r1_slice,
                             keys_literal.Slice({row, 0}, {row + 1, r1_length})
-                                ->Reshape({r1_length}));
-        auto r1_result = sort_r1(*r1_slice);
-        TF_ASSIGN_OR_RETURN(r1_result, r1_result->Reshape({1, r1_length}));
-        TF_RETURN_IF_ERROR(result_literal->CopySliceFrom(
-            *r1_result, {0, 0}, {row, 0}, {1, r1_length}));
+                                .Reshape({r1_length}));
+        auto r1_result = sort_r1(r1_slice);
+        TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
+        TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+            r1_result, {0, 0}, {row, 0}, {1, r1_length}));
       }
       parent_->evaluated_[sort] = std::move(result_literal);
     }
@@ -1629,9 +1638,9 @@
     }
 
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    absl::InlinedVector<std::unique_ptr<Literal>, 1> results(num_args);
+    absl::InlinedVector<Literal, 1> results(num_args);
     for (int64 i = 0; i < num_args; ++i) {
-      results[i] = absl::make_unique<Literal>(result_shape);
+      results[i] = Literal(result_shape);
     }
 
     Status eval_status;
@@ -1645,7 +1654,7 @@
     }
 
     for (int64 input = 0; input < num_args; ++input) {
-      TF_RETURN_IF_ERROR(results[input]->Populate<ReturnT>(
+      TF_RETURN_IF_ERROR(results[input].Populate<ReturnT>(
           [&](absl::Span<const int64> multi_index) {
             if (!eval_status.ok()) {
               return init_scalars[input];
@@ -1681,8 +1690,7 @@
               }
 
               // Evaluate computation with specified literal operands.
-              absl::InlinedVector<std::unique_ptr<Literal>, 1>
-                  embedded_operands;
+              absl::InlinedVector<Literal, 1> embedded_operands;
               for (ReturnT value : result_values) {
                 embedded_operands.push_back(
                     LiteralUtil::CreateR0<ReturnT>(value));
@@ -1695,11 +1703,9 @@
                   embedded_operands.size());
               std::transform(embedded_operands.begin(), embedded_operands.end(),
                              embedded_operands_ptrs.begin(),
-                             [](const std::unique_ptr<Literal>& ptr) {
-                               return ptr.get();
-                             });
+                             [](Literal& literal) { return &literal; });
 
-              TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result,
+              TF_ASSIGN_OR_RETURN(Literal computed_result,
                                   embedded_evaluator.Evaluate<const Literal*>(
                                       *function, embedded_operands_ptrs));
               // Clear visit states so that we can use the evaluator again on
@@ -1707,10 +1713,10 @@
               embedded_evaluator.ResetVisitStates();
               // Assign computed result to result_val.
               if (!has_tuple_output) {
-                result_values[0] = computed_result->Get<ReturnT>({});
+                result_values[0] = computed_result.Get<ReturnT>({});
               } else {
                 for (int64 i = 0; i < num_args; ++i) {
-                  result_values[i] = computed_result->Get<ReturnT>(
+                  result_values[i] = computed_result.Get<ReturnT>(
                       /*multi_index=*/{}, /*shape_index=*/{i});
                 }
               }
@@ -1726,9 +1732,9 @@
     if (!has_tuple_output) {
       parent_->evaluated_[reduce] = std::move(results[0]);
     } else {
-      auto tuple_result = absl::make_unique<Literal>(reduce->shape());
+      Literal tuple_result(reduce->shape());
       for (int64 i = 0; i < num_args; ++i) {
-        TF_CHECK_OK(tuple_result->MoveFrom(std::move(*results[i]), {i}));
+        TF_CHECK_OK(tuple_result.MoveFrom(std::move(results[i]), {i}));
       }
       parent_->evaluated_[reduce] = std::move(tuple_result);
     }
@@ -1759,10 +1765,10 @@
     TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
     auto init_scalar = init_literal.Get<ReturnT>({});
 
-    auto result = absl::make_unique<Literal>(select_and_scatter->shape());
+    Literal result(select_and_scatter->shape());
 
     // Initialize result array with the init value.
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(
         [&](absl::Span<const int64> output_index) { return init_scalar; }));
 
     std::vector<int64> window_dimension_sizes;
@@ -1812,15 +1818,14 @@
               selected_val = curr_val;
               selected_index = operand_index;
             }
-            curr_val_literal->Set({}, curr_val);
-            selected_val_literal->Set({}, *selected_val);
-            std::unique_ptr<Literal> computed_result =
+            curr_val_literal.Set({}, curr_val);
+            selected_val_literal.Set({}, *selected_val);
+            Literal computed_result =
                 embedded_evaluator
                     .Evaluate<const Literal*>(
-                        *select,
-                        {selected_val_literal.get(), curr_val_literal.get()})
+                        *select, {&selected_val_literal, &curr_val_literal})
                     .ConsumeValueOrDie();
-            bool selected = !computed_result->Get<bool>({});
+            bool selected = !computed_result.Get<bool>({});
             if (selected) {
               selected_val = curr_val;
               selected_index = operand_index;
@@ -1834,16 +1839,16 @@
             if (std::equal(operand_index.begin(), operand_index.end(),
                            selected_index->begin())) {
               auto source = source_literal.Get<ReturnT>(source_index);
-              auto scattered = result->Get<ReturnT>(operand_index);
-              source_literal_scatter->Set({}, source);
-              scattered_literal->Set({}, scattered);
-              std::unique_ptr<Literal> computed_result =
+              auto scattered = result.Get<ReturnT>(operand_index);
+              source_literal_scatter.Set({}, source);
+              scattered_literal.Set({}, scattered);
+              Literal computed_result =
                   embedded_evaluator
-                      .Evaluate<const Literal*>(*scatter,
-                                                {source_literal_scatter.get(),
-                                                 scattered_literal.get()})
+                      .Evaluate<const Literal*>(
+                          *scatter,
+                          {&source_literal_scatter, &scattered_literal})
                       .ConsumeValueOrDie();
-              result->Set(operand_index, computed_result->Get<ReturnT>({}));
+              result.Set(operand_index, computed_result.Get<ReturnT>({}));
               // Clear visit states so that the we can use the evaluator again
               // on the same computation.
               embedded_evaluator.ResetVisitStates();
@@ -1894,10 +1899,10 @@
     DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape()));
 
     HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
-    auto result = absl::make_unique<Literal>(reduce_window->shape());
+    Literal result(reduce_window->shape());
     // For each resulting dimension, calculate and assign computed value.
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> output_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> output_index) {
           ReturnT result_val = init_scalar;
 
           std::fill(window_index.begin(), window_index.end(), 0);
@@ -1913,18 +1918,17 @@
                     LiteralUtil::CreateR0<ReturnT>(curr_val);
                 const auto result_val_literal =
                     LiteralUtil::CreateR0<ReturnT>(result_val);
-                std::unique_ptr<Literal> computed_result =
+                Literal computed_result =
                     embedded_evaluator
                         .Evaluate<const Literal*>(
-                            *function,
-                            {result_val_literal.get(), curr_val_literal.get()})
+                            *function, {&result_val_literal, &curr_val_literal})
                         .ConsumeValueOrDie();
 
                 // Clear visit states so that the we can use the evaluate again
                 // on the same computation.
                 embedded_evaluator.ResetVisitStates();
 
-                result_val = computed_result->Get<ReturnT>({});
+                result_val = computed_result.Get<ReturnT>({});
               });
 
           return result_val;
@@ -1939,7 +1943,7 @@
   // literal (if there is one) to `reshaped_indices`.
   StatusOr<std::reference_wrapper<const Literal>> ReshapedScatterIndices(
       int64 index_vector_dim, const Literal& indices,
-      std::unique_ptr<Literal>* reshaped_indices) {
+      Literal* reshaped_indices) {
     if (indices.shape().dimensions_size() != index_vector_dim) {
       return std::cref(indices);
     }
@@ -1948,7 +1952,7 @@
                                  indices.shape().dimensions().end());
     new_shape.push_back(1);
     TF_ASSIGN_OR_RETURN(*reshaped_indices, indices.Reshape(new_shape));
-    return std::cref(**reshaped_indices);
+    return std::cref(*reshaped_indices);
   }
 
   // Returns an ShapeUtil::IndexIterationSpace that iterates over the update
@@ -2208,7 +2212,7 @@
         scatter->scatter_dimension_numbers();
     const Literal& operand =
         parent_->GetEvaluatedLiteralFor(scatter->operand(0));
-    std::unique_ptr<Literal> reshaped_scatter_indices;
+    Literal reshaped_scatter_indices;
     TF_ASSIGN_OR_RETURN(const Literal& scatter_indices,
                         ReshapedScatterIndices(dim_numbers.index_vector_dim(),
                                                parent_->GetEvaluatedLiteralFor(
@@ -2238,7 +2242,7 @@
 
     // Initialize the result with the operand. This makes it easier to handle
     // the updates even when the indices are repeated.
-    std::unique_ptr<Literal> result = operand.CloneToUnique();
+    Literal result = operand.Clone();
     HloEvaluator embedded_evaluator;
     auto scatter_inner_loop_body =
         [&](absl::Span<const int64> update_window_index,
@@ -2277,19 +2281,19 @@
       }
 
       auto result_value_literal =
-          LiteralUtil::CreateR0<ReturnT>(result->Get<ReturnT>(input_index));
+          LiteralUtil::CreateR0<ReturnT>(result.Get<ReturnT>(input_index));
       auto update_value_literal =
           LiteralUtil::CreateR0<ReturnT>(updates.Get<ReturnT>(update_index));
-      std::unique_ptr<Literal> updated_result =
+      Literal updated_result =
           embedded_evaluator
               .Evaluate<const Literal*>(
                   *scatter->to_apply(),
-                  {result_value_literal.get(), update_value_literal.get()})
+                  {&result_value_literal, &update_value_literal})
               .ConsumeValueOrDie();
       // Clear visit states so that the we can use the evaluate again on the
       // same computation.
       embedded_evaluator.ResetVisitStates();
-      result->Set<ReturnT>(input_index, updated_result->Get<ReturnT>({}));
+      result.Set<ReturnT>(input_index, updated_result.Get<ReturnT>({}));
       return true;
     };
 
@@ -2337,9 +2341,8 @@
       return operand_literal.Get<ReturnT>(operand_index);
     };
 
-    auto result = LiteralUtil::CreateFromDimensions(
-        shape.element_type(), AsInt64Slice(shape.dimensions()));
-    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+    Literal result(shape);
+    TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
     parent_->evaluated_[slice] = std::move(result);
     return Status::OK();
   }
@@ -2553,7 +2556,7 @@
     if (ShapeUtil::Rank(iota->shape()) > 1) {
       TF_ASSIGN_OR_RETURN(
           parent_->evaluated_[iota],
-          result->Broadcast(iota->shape(), {iota->iota_dimension()}));
+          result.Broadcast(iota->shape(), {iota->iota_dimension()}));
     } else {
       TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1);
       parent_->evaluated_[iota] = std::move(result);
@@ -2623,9 +2626,9 @@
   }
 
   template <typename IndexT>
-  StatusOr<std::unique_ptr<Literal>> DynamicSlice(
-      const Literal& operand_literal, const Literal& start_indices_literal,
-      const Shape& result_shape) {
+  StatusOr<Literal> DynamicSlice(const Literal& operand_literal,
+                                 const Literal& start_indices_literal,
+                                 const Shape& result_shape) {
     auto start_indices_typed = start_indices_literal.data<IndexT>();
     std::vector<int64> start(start_indices_typed.begin(),
                              start_indices_typed.end());
@@ -2638,9 +2641,9 @@
     }
 
     std::vector<int64> operand_indices(start.size());
-    auto result = absl::make_unique<Literal>(result_shape);
+    Literal result(result_shape);
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
           for (int64 i = 0; i < operand_indices.size(); ++i) {
             CHECK_GE(multi_index[i] + start[i], 0);
             operand_indices[i] = multi_index[i] + start[i];
@@ -2654,12 +2657,12 @@
   }
 
   template <typename IndexT>
-  StatusOr<std::unique_ptr<Literal>> DynamicUpdateSlice(
-      const Literal& operand_literal, const Literal& update_literal,
-      const Literal& start_indices_literal) {
-    auto result = operand_literal.CloneToUnique();
+  StatusOr<Literal> DynamicUpdateSlice(const Literal& operand_literal,
+                                       const Literal& update_literal,
+                                       const Literal& start_indices_literal) {
+    auto result = operand_literal.Clone();
     auto start_indices_typed = start_indices_literal.data<IndexT>();
-    const auto rank = ShapeUtil::Rank(result->shape());
+    const auto rank = ShapeUtil::Rank(result.shape());
     std::vector<int64> start(start_indices_typed.begin(),
                              start_indices_typed.end());
     // Clamp the update start indices so the slice is in-bounds w.r.t the
@@ -2667,15 +2670,15 @@
     for (int64 i = 0; i < rank; ++i) {
       start[i] = std::min<int64>(
           std::max<int64>(0, start[i]),
-          result->shape().dimensions(i) - update_literal.shape().dimensions(i));
+          result.shape().dimensions(i) - update_literal.shape().dimensions(i));
     }
     std::vector<int64> result_index(rank, 0);
 
     auto func = [&](absl::Span<const int64> update_index) {
       std::transform(update_index.begin(), update_index.end(), start.begin(),
                      result_index.begin(), std::plus<int64>());
-      result->Set<ReturnT>(result_index,
-                           update_literal.Get<ReturnT>(update_index));
+      result.Set<ReturnT>(result_index,
+                          update_literal.Get<ReturnT>(update_index));
       return true;
     };
 
@@ -2688,7 +2691,7 @@
     return std::move(result);
   }
 
-  StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOp(
+  StatusOr<Literal> ElementWiseUnaryOp(
       HloInstruction* instruction,
       const std::function<ElementwiseT(ElementwiseT)>& unary_op) {
     const Literal& operand_literal =
@@ -2701,7 +2704,7 @@
     return std::move(result_literal);
   }
 
-  StatusOr<std::unique_ptr<Literal>> ElementWiseBinaryOp(
+  StatusOr<Literal> ElementWiseBinaryOp(
       HloInstruction* instruction,
       const std::function<ElementwiseT(ElementwiseT, ElementwiseT)>&
           binary_op) {
@@ -2723,10 +2726,10 @@
     const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
 
-    auto result = absl::make_unique<Literal>(shape);
+    Literal result(shape);
 
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
           return ConvertBinaryFunction(binary_op)(
               lhs_literal.Get<ReturnT>(multi_index),
               rhs_literal.Get<ReturnT>(multi_index));
@@ -2735,7 +2738,7 @@
   }
 
   template <typename LhsType, typename RhsType, typename EhsType>
-  StatusOr<std::unique_ptr<Literal>> ElementwiseTernaryOp(
+  StatusOr<Literal> ElementwiseTernaryOp(
       HloInstruction* instruction,
       const std::function<ReturnT(LhsType, RhsType, EhsType)>& ternary_op) {
     const auto shape = instruction->shape();
@@ -2760,10 +2763,10 @@
     const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
     const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs);
 
-    auto result = absl::make_unique<Literal>(shape);
+    Literal result(shape);
 
     TF_RETURN_IF_ERROR(
-        result->Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
+        result.Populate<ReturnT>([&](absl::Span<const int64> multi_index) {
           return ternary_op(lhs_literal.Get<LhsType>(multi_index),
                             rhs_literal.Get<RhsType>(multi_index),
                             ehs_literal.Get<EhsType>(multi_index));
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 3041d94..287ba84 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -120,12 +120,23 @@
   std::function<NodeFilterResult(const HloInstruction* instr)> filter_;
 };
 
+// We arbitrarily set this as the boundary between "large" and "small"
+// instructions.
+bool IsSmall(const HloInstruction* instr) {
+  if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
+      ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
+    return true;
+  }
+  return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
+}
+
 // Node color schemes, used by NodeColorAttributes.
 enum ColorScheme {
   kBlue,
   kBrown,
   kDarkBlue,
   kDarkGreen,
+  kDarkOrange,
   kDarkRed,
   kGray,
   kGreen,
@@ -158,6 +169,10 @@
       return NodeColors{"filled", "#1565c0", "#003c8f", "white"};
     case kDarkGreen:
       return NodeColors{"filled", "#2e7d32", "#005005", "white"};
+    case kDarkOrange:
+      // This is more of a "medium" orange, made to look close to kOrange;
+      // there's probably room for a darker weight if desired.
+      return NodeColors{"filled", "#ffb74d", "#c88719", "black"};
     case kDarkRed:
       return NodeColors{"filled", "#b71c1c", "#7f0000", "white"};
     case kGray:
@@ -454,9 +469,8 @@
   string graph_label =
       StrCat(label_, "<br/>Computation ", computation_->name());
   if (computation_->IsFusionComputation()) {
-    StrAppend(&graph_label,
-              StrCat(" (in fusion instruction ",
-                     computation_->FusionInstruction()->name(), ")"));
+    StrAppend(&graph_label, " (in fusion instruction ",
+              computation_->FusionInstruction()->name(), ")");
   }
   if (profile_ != nullptr) {
     auto cycles = profile_->total_cycles_executed(*computation_);
@@ -893,7 +907,10 @@
     sharding_colors_.emplace(instr->sharding(), color);
     return color;
   }
-  const auto kParameterColor = kOrange;
+
+  // Choose different weights of orange for small vs large parameters.  This
+  // distinction is often important, especially in fusion nodes.
+  auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange;
 
   // Special case: If this instruction has a parameter merged into it, paint it
   // the same color as a parameter.  Unless the merged-in parameter is a
@@ -905,7 +922,7 @@
                            ShouldMergeIntoUsers(operand) &&
                            TryGetFusionParameterConstant(operand) == nullptr;
                   })) {
-    return kParameterColor;
+    return parameter_color;
   }
 
   // Pick different colors or shapes for instructions which are particularly
@@ -1015,7 +1032,7 @@
     case HloOpcode::kReducePrecision:
       return kRed;
     case HloOpcode::kParameter:
-      return kParameterColor;
+      return parameter_color;
     case HloOpcode::kBatchNormGrad:
     case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormTraining:
@@ -1160,20 +1177,6 @@
   return StrJoin(lines, "<br/>");
 }
 
-// Gets the total number of array elements in the given shape.  For tuples, this
-// is the sum of all the sizes of all of the array elements recursively in the
-// tuple.
-static int64 TotalElementsInShape(const Shape& shape) {
-  int64 elems = 0;
-  ShapeUtil::ForEachSubshape(
-      shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) {
-        if (ShapeUtil::IsArray(subshape)) {
-          elems += ShapeUtil::ElementsIn(subshape);
-        }
-      });
-  return elems;
-}
-
 void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
   auto add_edge = [&](const HloInstruction* from, const HloInstruction* to,
                       int64 operand_num, bool control_edge = false) {
@@ -1196,14 +1199,11 @@
     }
 
     // We print "small" arrays using a hollow arrowhead and "large" arrays using
-    // a filled arrowhead.  For now, we use an arbitrary cutoff for what "big"
-    // means.
-    bool is_big_array = TotalElementsInShape(from->shape()) >= 4096;
-
+    // a filled arrowhead.
     constexpr char kEdgeFmt[] =
         R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)";
     edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to),
-                               (is_big_array ? "normal" : "empty"),
+                               (IsSmall(from) ? "empty" : "normal"),
                                from->name(), to->name(), edge_label));
   };
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6d13f85..e905f29 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -250,7 +250,7 @@
       TF_RET_CHECK(proto.has_literal());
       TF_ASSIGN_OR_RETURN(auto literal,
                           Literal::CreateFromProto(proto.literal()));
-      instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
       break;
     }
     case HloOpcode::kFusion: {
@@ -341,17 +341,21 @@
                                             source_target_pairs);
       break;
     }
-    case HloOpcode::kConvolution:
+    case HloOpcode::kConvolution: {
       TF_RET_CHECK(proto.operand_ids_size() == 2)
           << "Convolution instruction should have 2 operands but sees "
           << proto.operand_ids_size();
       TF_RET_CHECK(proto.has_window());
       TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+      PrecisionConfig precision_config = proto.precision_config();
+      precision_config.mutable_operand_precision()->Resize(
+          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
       instruction = CreateConvolve(
-          proto.shape(), operands(0), operands(1), proto.window(),
-          proto.convolution_dimension_numbers(),
-          std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
+          proto.shape(), operands(0), operands(1),
+          std::max<int64>(proto.feature_group_count(), 1), proto.window(),
+          proto.convolution_dimension_numbers(), precision_config);
       break;
+    }
     case HloOpcode::kReduceWindow:
       TF_RET_CHECK(proto.operand_ids_size() == 2)
           << "ReduceWindow instruction should have 2 operands but sees "
@@ -447,6 +451,28 @@
           << proto.dimensions_size();
       instruction = CreateIota(proto.shape(), proto.dimensions(0));
       break;
+    case HloOpcode::kDot: {
+      TF_RET_CHECK(proto.has_dot_dimension_numbers())
+          << "Dot instruction should have dot_dimension_numbers.";
+      TF_RET_CHECK(proto.operand_ids_size() == 2)
+          << "Dot instruction should have 2 operands but sees "
+          << proto.operand_ids_size();
+      PrecisionConfig precision_config = proto.precision_config();
+      precision_config.mutable_operand_precision()->Resize(
+          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+      instruction = absl::make_unique<HloDotInstruction>(
+          proto.shape(), operands(0), operands(1),
+          proto.dot_dimension_numbers(), precision_config);
+      break;
+    }
+    case HloOpcode::kDomain:
+      TF_RET_CHECK(proto.operand_ids_size() == 1)
+          << "Domain instruction should have 1 operands but sees "
+          << proto.operand_ids_size();
+      instruction = absl::make_unique<HloDomainInstruction>(
+          proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
+          /*user_side_metadata=*/nullptr);
+      break;
     default: {
       instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
       for (const int64 operand_id : proto.operand_ids()) {
@@ -468,6 +494,9 @@
               computation_map.at(computation_id));
         }
       }
+      TF_RET_CHECK(!proto.has_precision_config())
+          << instruction->opcode() << proto.DebugString();
+      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
       break;
     }
   }
@@ -476,12 +505,7 @@
   instruction->SetAndSanitizeName(proto.name());
   instruction->metadata_ = proto.metadata();
   instruction->backend_config_ = proto.backend_config();
-  instruction->precision_config_ = proto.precision_config();
-
-  if (proto.has_dot_dimension_numbers()) {
-    instruction->dot_dimension_numbers_ =
-        absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers());
-  }
+  instruction->unique_id_ = proto.id();
 
   if (proto.has_sharding()) {
     TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -504,7 +528,7 @@
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
-    std::unique_ptr<Literal> literal) {
+    Literal literal) {
   return absl::make_unique<HloConstantInstruction>(std::move(literal));
 }
 
@@ -552,7 +576,6 @@
     case HloOpcode::kCopy:
     case HloOpcode::kCos:
     case HloOpcode::kClz:
-    case HloOpcode::kDomain:
     case HloOpcode::kExp:
     case HloOpcode::kExpm1:
     case HloOpcode::kFloor:
@@ -584,7 +607,6 @@
     case HloOpcode::kAtan2:
     case HloOpcode::kDivide:
     case HloOpcode::kComplex:
-    case HloOpcode::kDot:
     case HloOpcode::kEq:
     case HloOpcode::kGe:
     case HloOpcode::kGt:
@@ -643,10 +665,12 @@
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count) {
+    int64 feature_group_count, const Window& window,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config) {
   return absl::make_unique<HloConvolutionInstruction>(
-      shape, lhs, rhs, window, dimension_numbers, feature_group_count);
+      shape, lhs, rhs, feature_group_count, window, dimension_numbers,
+      precision_config);
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
@@ -658,30 +682,10 @@
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-    const DotDimensionNumbers& dimension_numbers) {
-  auto instruction =
-      absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
-  instruction->AppendOperand(lhs);
-  instruction->AppendOperand(rhs);
-  instruction->dot_dimension_numbers_ =
-      absl::make_unique<DotDimensionNumbers>(dimension_numbers);
-  return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
-    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
-  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
-  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
-  auto instruction =
-      absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
-  instruction->AppendOperand(lhs);
-  instruction->AppendOperand(rhs);
-  instruction->dot_dimension_numbers_ =
-      absl::make_unique<DotDimensionNumbers>();
-  instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
-  instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
-  return instruction;
+    const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config) {
+  return absl::make_unique<HloDotInstruction>(
+      shape, lhs, rhs, dimension_numbers, precision_config);
 }
 
 /* static */ std::unique_ptr<HloInstruction>
@@ -1057,7 +1061,6 @@
     derived_instruction->clear_sharding();
   }
   derived_instruction->set_metadata(metadata_);
-  derived_instruction->set_precision_config(precision_config_);
 }
 
 bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -1142,12 +1145,9 @@
     const Shape& shape, HloInstruction* operand,
     std::unique_ptr<DomainMetadata> operand_side_metadata,
     std::unique_ptr<DomainMetadata> user_side_metadata) {
-  auto instruction =
-      absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
-  instruction->operand_side_metadata_ = std::move(operand_side_metadata);
-  instruction->user_side_metadata_ = std::move(user_side_metadata);
-  instruction->AppendOperand(operand);
-  return instruction;
+  return absl::make_unique<HloDomainInstruction>(
+      shape, operand, std::move(operand_side_metadata),
+      std::move(user_side_metadata));
 }
 
 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
@@ -1203,6 +1203,8 @@
     case HloOpcode::kGather:
     case HloOpcode::kScatter:
     case HloOpcode::kIota:
+    case HloOpcode::kDot:
+    case HloOpcode::kDomain:
       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
       break;
     // Unary ops.
@@ -1275,11 +1277,6 @@
       CHECK_EQ(new_operands.size(), 1);
       clone = CreateBitcastConvert(shape, new_operands[0]);
       break;
-    case HloOpcode::kDot:
-      CHECK_EQ(new_operands.size(), 2);
-      clone = CreateDot(shape, new_operands[0], new_operands[1],
-                        *dot_dimension_numbers_);
-      break;
     case HloOpcode::kReshape:
       CHECK_EQ(new_operands.size(), 1);
       clone = CreateReshape(shape, new_operands[0]);
@@ -1304,12 +1301,6 @@
                                 true_computation(), new_operands[2],
                                 false_computation());
       break;
-    case HloOpcode::kDomain:
-      CHECK_EQ(new_operands.size(), 1);
-      clone =
-          CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
-                       user_side_metadata_->Clone());
-      break;
     case HloOpcode::kAfterAll:
       if (new_operands.empty()) {
         clone = CreateToken();
@@ -1605,11 +1596,6 @@
     case HloOpcode::kAfterAll:
       return false;
 
-    // Check dot dimension numbers.
-    case HloOpcode::kDot:
-      return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
-                                           other.dot_dimension_numbers());
-
     // Remaining instructions with special values.
     case HloOpcode::kCall:
       return eq_computations(to_apply(), other.to_apply());
@@ -1625,10 +1611,6 @@
       return false;
     }
 
-    case HloOpcode::kDomain:
-      return operand_side_metadata().Matches(other.operand_side_metadata()) &&
-             user_side_metadata().Matches(other.user_side_metadata());
-
     // Ops migrated to subclasses should never come to this line.
     // TODO(b/80131774): Remove this switch when migration is complete.
     case HloOpcode::kBatchNormTraining:
@@ -1668,6 +1650,8 @@
     case HloOpcode::kDynamicSlice:
     case HloOpcode::kGather:
     case HloOpcode::kScatter:
+    case HloOpcode::kDot:
+    case HloOpcode::kDomain:
       LOG(FATAL) << "Base class impl called for opcode with subclass: "
                  << opcode();
   }
@@ -2037,15 +2021,6 @@
     const HloPrintOptions& options) const {
   std::vector<string> extra = ExtraAttributesToStringImpl(options);
 
-  if (dot_dimension_numbers_ != nullptr) {
-    extra.push_back(DotDimensionNumbersToString());
-  }
-
-  string precision_config_string = PrecisionConfigToString();
-  if (!precision_config_string.empty()) {
-    extra.push_back(precision_config_string);
-  }
-
   if (options.print_subcomputation_mode() ==
       HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
     if (opcode() == HloOpcode::kWhile) {
@@ -2122,7 +2097,7 @@
   if (has_sharding()) {
     extra.push_back(StrCat("sharding=", sharding().ToString()));
   }
-  if (!control_predecessors_.empty()) {
+  if (options.print_control_dependencies() && !control_predecessors_.empty()) {
     extra.push_back(StrCat("control-predecessors={",
                            StrJoin(control_predecessors_, ", ",
                                    [&](string* out, HloInstruction* pre) {
@@ -2131,11 +2106,6 @@
                                    }),
                            "}"));
   }
-  if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
-    extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
-                           "\", entry=", user_side_metadata_->ToString(),
-                           ", exit=", operand_side_metadata_->ToString(), "}"));
-  }
 
   return extra;
 }
@@ -2167,17 +2137,12 @@
 
   *proto.mutable_metadata() = metadata_;
   proto.set_backend_config(backend_config_);
-  *proto.mutable_precision_config() = precision_config_;
   if (opcode() != HloOpcode::kFusion) {
     for (const HloComputation* computation : called_computations_) {
       proto.add_called_computation_ids(computation->unique_id());
     }
   }
 
-  if (dot_dimension_numbers_ != nullptr) {
-    *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
-  }
-
   if (has_sharding()) {
     *proto.mutable_sharding() = sharding().ToProto();
   }
@@ -2871,8 +2836,8 @@
   return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
 }
 
-string PrecisionToString(const PrecisionConfigProto::Precision& precision) {
-  return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision));
+string PrecisionToString(const PrecisionConfig::Precision& precision) {
+  return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
 }
 
 string ConvolutionDimensionNumbersToString(
@@ -2904,31 +2869,6 @@
                 StrJoin(output_dims, ""));
 }
 
-string HloInstruction::DotDimensionNumbersToString() const {
-  std::vector<string> result;
-  if (dot_dimension_numbers_ == nullptr) {
-    return "";
-  }
-  const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
-  if (!dnums.lhs_batch_dimensions().empty()) {
-    result.push_back(StrCat("lhs_batch_dims={",
-                            StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
-  }
-  result.push_back(StrCat("lhs_contracting_dims={",
-                          StrJoin(dnums.lhs_contracting_dimensions(), ","),
-                          "}"));
-
-  if (!dnums.rhs_batch_dimensions().empty()) {
-    result.push_back(StrCat("rhs_batch_dims={",
-                            StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
-  }
-  result.push_back(StrCat("rhs_contracting_dims={",
-                          StrJoin(dnums.rhs_contracting_dimensions(), ","),
-                          "}"));
-
-  return StrJoin(result, ", ");
-}
-
 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
   static std::unordered_map<string, RandomDistribution>* map = [] {
     static auto* map = new std::unordered_map<string, RandomDistribution>;
@@ -2947,31 +2887,13 @@
   return found->second;
 }
 
-string HloInstruction::PrecisionConfigToString() const {
-  if (precision_config_.operand_precision().empty()) {
-    return "";
-  }
-  return StrCat(
-      "operand_precision={",
-      StrJoin(precision_config_.operand_precision(), ",",
-              [](string* out, int32 precision) {
-                CHECK(PrecisionConfigProto::Precision_IsValid(precision))
-                    << precision;
-                StrAppend(out, PrecisionToString(
-                                   static_cast<PrecisionConfigProto::Precision>(
-                                       precision)));
-              }),
-      "}");
-}
-
-StatusOr<PrecisionConfigProto::Precision> StringToPrecision(
-    const string& name) {
-  static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] {
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
+  static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
     static auto* map =
-        new std::unordered_map<string, PrecisionConfigProto::Precision>;
-    for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) {
-      if (PrecisionConfigProto::Precision_IsValid(i)) {
-        auto value = static_cast<PrecisionConfigProto::Precision>(i);
+        new std::unordered_map<string, PrecisionConfig::Precision>;
+    for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
+      if (PrecisionConfig::Precision_IsValid(i)) {
+        auto value = static_cast<PrecisionConfig::Precision>(i);
         (*map)[PrecisionToString(value)] = value;
       }
     }
@@ -3024,6 +2946,16 @@
   return ret;
 }
 
+const PrecisionConfig& HloInstruction::precision_config() const {
+  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
+    return convolution->precision_config();
+  }
+  if (auto* dot = DynCast<HloDotInstruction>(this)) {
+    return dot->precision_config();
+  }
+  LOG(FATAL) << "Unimplemented method.";
+}
+
 HloModule* HloInstruction::GetModule() const {
   if (parent_) {
     return parent_->parent();
@@ -3328,4 +3260,15 @@
   return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
 }
 
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
+const DomainMetadata& HloInstruction::operand_side_metadata() const {
+  return Cast<HloDomainInstruction>(this)->operand_side_metadata();
+}
+
+const DomainMetadata& HloInstruction::user_side_metadata() const {
+  return Cast<HloDomainInstruction>(this)->user_side_metadata();
+}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cca134e..4f6cac1 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -82,6 +82,7 @@
         print_operand_shape_(true),
         print_program_shape_(true),
         print_percent_(true),
+        print_control_dependencies_(true),
         canonicalize_instruction_names_(false),
         indent_amount_(0),
         is_in_nested_computation_(false) {}
@@ -94,7 +95,8 @@
         .set_print_backend_config(false)
         .set_print_operand_shape(false)
         .set_print_program_shape(false)
-        .set_print_percent(false);
+        .set_print_percent(false)
+        .set_print_control_dependencies(false);
   }
 
   // Options to produce the canonical string representing an isomorphic
@@ -108,6 +110,7 @@
         .set_print_operand_shape(true)
         .set_print_program_shape(false)
         .set_print_percent(false)
+        .set_print_control_dependencies(false)
         .set_canonicalize_instruction_names(true);
   }
 
@@ -153,6 +156,12 @@
     return *this;
   }
 
+  // If true, control dependencies will be printed.
+  HloPrintOptions& set_print_control_dependencies(bool value) {
+    print_control_dependencies_ = value;
+    return *this;
+  }
+
   // If true, only a part of operands will be printed out, and their names will
   // be omitted (note that in this case the text will not be parsable).
   HloPrintOptions& set_compact_operands(bool value) {
@@ -190,6 +199,9 @@
   bool print_operand_shape() const { return print_operand_shape_; }
   bool print_program_shape() const { return print_program_shape_; }
   bool print_percent() const { return print_percent_; }
+  bool print_control_dependencies() const {
+    return print_control_dependencies_;
+  }
   bool canonicalize_instruction_names() const {
     return canonicalize_instruction_names_;
   }
@@ -205,6 +217,7 @@
   bool print_operand_shape_;
   bool print_program_shape_;
   bool print_percent_;
+  bool print_control_dependencies_;
   bool canonicalize_instruction_names_;
   int indent_amount_;
   bool is_in_nested_computation_;
@@ -346,8 +359,7 @@
                                                          const string& name);
 
   // Creates a literal constant instruction.
-  static std::unique_ptr<HloInstruction> CreateConstant(
-      std::unique_ptr<Literal> literal);
+  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
 
   // Creates an Iota instruction.
   static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
@@ -405,9 +417,9 @@
   // and window describes how the filter is applied to lhs.
   static std::unique_ptr<HloInstruction> CreateConvolve(
       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-      const Window& window,
+      int64 feature_group_count, const Window& window,
       const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count = 1);
+      const PrecisionConfig& precision_config);
 
   // Creates an FFT op, of the type indicated by fft_type.
   static std::unique_ptr<HloInstruction> CreateFft(
@@ -418,13 +430,8 @@
   // dimensions specified in 'dimension_numbers'.
   static std::unique_ptr<HloInstruction> CreateDot(
       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-      const DotDimensionNumbers& dimension_numbers);
-
-  // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
-  // of the LHS with dimension 0 of the RHS with no batch dimensions.  Both LHS
-  // and the RHS must be of rank 2.
-  static std::unique_ptr<HloInstruction> CreateCanonicalDot(
-      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
+      const DotDimensionNumbers& dimension_numbers,
+      const PrecisionConfig& precision_config);
 
   // Creates a reduce-precision op, where operand is the data to reduce in
   // precision, and exponent_bits and mantissa_bits describe the precision to
@@ -865,11 +872,6 @@
       return false;
     }
 
-    if (!absl::c_equal(precision_config_.operand_precision(),
-                       other.precision_config_.operand_precision())) {
-      return false;
-    }
-
     return IdenticalSlowPath(other, eq_computations);
   }
 
@@ -1084,15 +1086,6 @@
     return other->has_sharding() ? sharding() == other->sharding() : false;
   }
 
-  // Retrieves the operand side metadata of a kDomain instruction.
-  const DomainMetadata& operand_side_metadata() const {
-    return *operand_side_metadata_;
-  }
-  // Retrieves the user side metadata of a kDomain instruction.
-  const DomainMetadata& user_side_metadata() const {
-    return *user_side_metadata_;
-  }
-
   // When creating a new instruction which either replaces, or shifts up (kCopy
   // insertion case), another instruction, we need to make sure the certain
   // properties of the new instruction are copied into the derived one. As of
@@ -1100,18 +1093,6 @@
   // instruction.
   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
 
-  // Returns data on the dimension numbers used for a dot operation.
-  const DotDimensionNumbers& dot_dimension_numbers() const {
-    CHECK(dot_dimension_numbers_ != nullptr);
-    return *dot_dimension_numbers_;
-  }
-
-  // Returns the dump string of the dot dimension numbers.
-  string DotDimensionNumbersToString() const;
-
-  // Returns the dump string of the precision configuration.
-  string PrecisionConfigToString() const;
-
   // Clones the HLO instruction. The clone will have the same opcode, shape, and
   // operands. After creation the clone has no uses. "this" (the instruction
   // cloned from) is not changed. Suffix is the string to append to the name of
@@ -1261,12 +1242,8 @@
   // information. Transformations to other HLOs will not preserve this
   // information but it is presumed that the alternate lowering is strictly
   // superior.
-  const PrecisionConfigProto& precision_config() const {
-    return precision_config_;
-  }
-  void set_precision_config(const PrecisionConfigProto& precision_config) {
-    precision_config_ = precision_config;
-  }
+  // Precondition: opcode must be kConvolution or kDot.
+  const PrecisionConfig& precision_config() const;
 
   // Sets the debug metadata for this instruction.
   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1509,6 +1486,15 @@
   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
 
+  // Delegates to HloDotInstruction::dot_dimension_numbers().
+  const DotDimensionNumbers& dot_dimension_numbers() const;
+
+  // Delegates to HloDomainInstruction::operand_side_metadata().
+  const DomainMetadata& operand_side_metadata() const;
+
+  // Delegates to HloDomainInstruction::user_side_metadata().
+  const DomainMetadata& user_side_metadata() const;
+
   // Old methods kept for smooth subclassing transition END.
 
  protected:
@@ -1648,22 +1634,12 @@
   // Result shape of this instruction.
   Shape shape_;
 
-  // Describes the dimension numbers used for a dot.
-  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
-  // Used to tag kCopy instructions that are eligible for copy elision.
-  bool copy_elision_allowed_ = true;
-
   // The sharding, if one exists.
   // Uses std::shared_ptr to allow reuse of the same sharding object between
   // HloInstructions and other components as HloSharding can be very large for
   // many element tuples.
   std::shared_ptr<const HloSharding> sharding_;
 
-  // Fields used by the kDomain instruction.
-  std::unique_ptr<DomainMetadata> operand_side_metadata_;
-  std::unique_ptr<DomainMetadata> user_side_metadata_;
-
   // Computations called by this instruction.
   std::vector<HloComputation*> called_computations_;
 
@@ -1677,10 +1653,6 @@
   // HLO. See the documentation on backend_config().
   string backend_config_;
 
-  // Information used to communicate to the implementation about the algorithm
-  // used to produce results. See the documentation on precision_config().
-  PrecisionConfigProto precision_config_;
-
   // String identifier for instruction.
   string name_;
 
@@ -1703,12 +1675,12 @@
 string PaddingConfigToString(const PaddingConfig& padding);
 string OpMetadataToString(const OpMetadata& metadata);
 string RandomDistributionToString(const RandomDistribution& distribution);
-string PrecisionToString(const PrecisionConfigProto::Precision& precision);
+string PrecisionToString(const PrecisionConfig::Precision& precision);
 string ConvolutionDimensionNumbersToString(
     const ConvolutionDimensionNumbers& dnums);
 
 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
-StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name);
+StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
 
 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 76b0e94..c1b7c38 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1147,8 +1147,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
@@ -1188,8 +1188,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      s, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
@@ -1239,8 +1239,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
+  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
+      data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
   auto one = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
   auto add_operand = builder.AddInstruction(
@@ -1320,8 +1320,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto options = HloPrintOptions().set_print_metadata(false);
 
@@ -1485,8 +1485,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto options = HloPrintOptions().Canonical();
 
@@ -1527,8 +1527,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
@@ -1583,8 +1583,8 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
-  HloInstruction* dot = builder.AddInstruction(
-      HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
+  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+      sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule();
   auto* computation = module->AddEntryComputation(builder.Build());
@@ -1752,9 +1752,9 @@
   auto* conv = module->entry_computation()->root_instruction();
 
   auto clone = conv->Clone();
-  EXPECT_THAT(clone->precision_config().operand_precision(),
-              ::testing::ElementsAre(PrecisionConfigProto::HIGH,
-                                     PrecisionConfigProto::DEFAULT));
+  EXPECT_THAT(
+      clone->precision_config().operand_precision(),
+      ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e46afa7..e92882c 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -47,6 +47,27 @@
         return instruction->IsElementwiseOnOperand(operand_index);
       });
 }
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+  if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+        return static_cast<PrecisionConfig::Precision>(precision) ==
+               PrecisionConfig::DEFAULT;
+      })) {
+    return "";
+  }
+
+  return StrCat(
+      "operand_precision={",
+      StrJoin(
+          precision_config.operand_precision(), ",",
+          [](string* out, int32 precision) {
+            CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+            StrAppend(out,
+                      PrecisionToString(
+                          static_cast<PrecisionConfig::Precision>(precision)));
+          }),
+      "}");
+}
 }  // namespace
 
 HloBatchNormInstruction::HloBatchNormInstruction(
@@ -824,8 +845,8 @@
       shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_);
 }
 
-HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
-    : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+HloConstantInstruction::HloConstantInstruction(Literal literal)
+    : HloInstruction(HloOpcode::kConstant, literal.shape()),
       literal_(std::move(literal)) {}
 
 HloConstantInstruction::HloConstantInstruction(const Shape& shape)
@@ -833,7 +854,7 @@
 
 HloInstructionProto HloConstantInstruction::ToProto() const {
   HloInstructionProto proto = HloInstruction::ToProto();
-  if (literal_ != nullptr) {
+  if (literal_.has_value()) {
     *proto.mutable_literal() = literal_->ToProto();
   }
   return proto;
@@ -855,7 +876,7 @@
 
   if (!mutable_array_subshape->has_layout() ||
       !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
-    literal_ = literal_->Relayout(new_layout, shape_index);
+    *literal_ = literal_->Relayout(new_layout, shape_index);
     *mutable_array_subshape->mutable_layout() = new_layout;
   }
 }
@@ -872,7 +893,8 @@
 HloConstantInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
-  return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique());
+  CHECK(literal_.has_value());
+  return absl::make_unique<HloConstantInstruction>(literal_->Clone());
 }
 
 string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
@@ -880,7 +902,7 @@
     CanonicalNameMap* canonical_name_map) const {
   string operands;
   // For constants, show the actual value in place of an empty operand list.
-  if (literal_ != nullptr &&
+  if (literal_.has_value() &&
       ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
        options.print_large_constants())) {
     // Literal::ToString emits multidimensional arrays over multiple
@@ -915,7 +937,7 @@
 
 HloInstructionProto HloTraceInstruction::ToProto() const {
   HloInstructionProto proto = HloInstruction::ToProto();
-  *proto.mutable_literal() = literal_->ToProto();
+  *proto.mutable_literal() = literal_.ToProto();
   return proto;
 }
 
@@ -1628,12 +1650,14 @@
 
 HloConvolutionInstruction::HloConvolutionInstruction(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-    const Window& window, const ConvolutionDimensionNumbers& dimension_numbers,
-    int64 feature_group_count)
+    int64 feature_group_count, const Window& window,
+    const ConvolutionDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config)
     : HloInstruction(HloOpcode::kConvolution, shape),
+      feature_group_count_(feature_group_count),
       window_(window),
       convolution_dimension_numbers_(dimension_numbers),
-      feature_group_count_(feature_group_count) {
+      precision_config_(precision_config) {
   if (window_util::HasBaseDilation(window)) {
     SetAndSanitizeName(StrCat(name(), "-base-dilated"));
   }
@@ -1661,6 +1685,7 @@
   *proto.mutable_convolution_dimension_numbers() =
       convolution_dimension_numbers_;
   proto.set_feature_group_count(feature_group_count_);
+  *proto.mutable_precision_config() = precision_config_;
   return proto;
 }
 
@@ -1672,7 +1697,15 @@
   }
   extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
                                             convolution_dimension_numbers_)));
-  extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+  if (feature_group_count_ != 1) {
+    extra.push_back(StrCat("feature_group_count=", feature_group_count_));
+  }
+
+  string precision_config_string = PrecisionConfigToString(precision_config_);
+  if (!precision_config_string.empty()) {
+    extra.push_back(precision_config_string);
+  }
+
   return extra;
 }
 
@@ -1688,7 +1721,9 @@
   return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
          protobuf_util::ProtobufEquals(
              convolution_dimension_numbers(),
-             casted_other.convolution_dimension_numbers());
+             casted_other.convolution_dimension_numbers()) &&
+         protobuf_util::ProtobufEquals(precision_config(),
+                                       casted_other.precision_config());
 }
 
 std::unique_ptr<HloInstruction>
@@ -1697,8 +1732,8 @@
     HloCloneContext* context) const {
   CHECK_EQ(new_operands.size(), 2);
   return absl::make_unique<HloConvolutionInstruction>(
-      shape, new_operands[0], new_operands[1], window(),
-      convolution_dimension_numbers_, feature_group_count_);
+      shape, new_operands[0], new_operands[1], feature_group_count_, window(),
+      convolution_dimension_numbers_, precision_config_);
 }
 
 HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -2157,4 +2192,113 @@
   return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
 }
 
+HloDotInstruction::HloDotInstruction(
+    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+    const DotDimensionNumbers& dimension_numbers,
+    const PrecisionConfig& precision_config)
+    : HloInstruction(HloOpcode::kDot, shape),
+      dot_dimension_numbers_(dimension_numbers),
+      precision_config_(precision_config) {
+  AppendOperand(lhs);
+  AppendOperand(rhs);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+  *proto.mutable_precision_config() = precision_config_;
+  return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  std::vector<string> extra = {DotDimensionNumbersToString()};
+
+  string precision_config_string = PrecisionConfigToString(precision_config_);
+  if (!precision_config_string.empty()) {
+    extra.push_back(precision_config_string);
+  }
+  return extra;
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+  return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+                                       casted_other.dot_dimension_numbers()) &&
+         protobuf_util::ProtobufEquals(precision_config(),
+                                       casted_other.precision_config());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 2);
+  return absl::make_unique<HloDotInstruction>(
+      shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+      precision_config_);
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+  std::vector<string> result;
+  const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+  if (!dnums.lhs_batch_dimensions().empty()) {
+    result.push_back(StrCat("lhs_batch_dims={",
+                            StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+  }
+  result.push_back(StrCat("lhs_contracting_dims={",
+                          StrJoin(dnums.lhs_contracting_dimensions(), ","),
+                          "}"));
+
+  if (!dnums.rhs_batch_dimensions().empty()) {
+    result.push_back(StrCat("rhs_batch_dims={",
+                            StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+  }
+  result.push_back(StrCat("rhs_contracting_dims={",
+                          StrJoin(dnums.rhs_contracting_dimensions(), ","),
+                          "}"));
+
+  return StrJoin(result, ", ");
+}
+
+HloDomainInstruction::HloDomainInstruction(
+    const Shape& shape, HloInstruction* operand,
+    std::unique_ptr<DomainMetadata> operand_side_metadata,
+    std::unique_ptr<DomainMetadata> user_side_metadata)
+    : HloInstruction(HloOpcode::kDomain, shape),
+      operand_side_metadata_(std::move(operand_side_metadata)),
+      user_side_metadata_(std::move(user_side_metadata)) {
+  AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+    return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+                   "\", entry=", user_side_metadata_->ToString(),
+                   ", exit=", operand_side_metadata_->ToString(), "}")};
+  }
+  return {};
+}
+
+bool HloDomainInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+  return operand_side_metadata().Matches(
+             casted_other.operand_side_metadata()) &&
+         user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  return absl::make_unique<HloDomainInstruction>(
+      shape, new_operands[0], operand_side_metadata_->Clone(),
+      user_side_metadata_->Clone());
+}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 3230383..2d7bc83 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -580,13 +580,13 @@
 
 class HloConstantInstruction : public HloInstruction {
  public:
-  explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+  explicit HloConstantInstruction(Literal literal);
   // Used when the literal is too large and dropped.
   explicit HloConstantInstruction(const Shape& shape);
   // Returns the literal associated with this instruction.
   const Literal& literal() const { return *literal_; }
   // Returns whether there is literal associated with this instruction.
-  bool HasLiteral() const { return literal_ != nullptr; }
+  bool HasLiteral() const { return literal_.has_value(); }
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
 
@@ -610,15 +610,14 @@
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
-  // TODO(b/36360764): Remove unique_ptr wrapping.
-  std::unique_ptr<Literal> literal_;
+  absl::optional<Literal> literal_;
 };
 
 class HloTraceInstruction : public HloInstruction {
  public:
   explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
   // Returns a tag to be used in tracing.
-  string TracingTag() const { return literal_->GetR1U8AsString(); }
+  string TracingTag() const { return literal_.GetR1U8AsString(); }
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
 
@@ -631,8 +630,7 @@
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
-  // TODO(b/36360764): Remove unique_ptr wrapping.
-  std::unique_ptr<Literal> literal_;
+  Literal literal_;
 };
 
 class HloFusionInstruction : public HloInstruction {
@@ -942,9 +940,9 @@
  public:
   explicit HloConvolutionInstruction(
       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
-      const Window& window,
+      int64 feature_group_count, const Window& window,
       const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count);
+      const PrecisionConfig& precision_config);
   const Window& window() const override { return window_; }
   void set_window(const Window& window) override { window_ = window; }
   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
@@ -957,6 +955,16 @@
   // The number of feature groups. Must be a divisor of the input feature
   // dimension and output feature dimension.
   int64 feature_group_count() const { return feature_group_count_; }
+
+  // Returns the information used to tell the implementation information about
+  // what sort of precision is requested. The meaning of the field is backend
+  // specific. At the moment, it is only supported for kConvolution and kDot.
+  // Transformations on one kDot or kConvolution to another will preserve this
+  // information. Transformations to other HLOs will not preserve this
+  // information but it is presumed that the alternate lowering is strictly
+  // superior.
+  const PrecisionConfig& precision_config() const { return precision_config_; }
+
   string ToCategory() const override;
   // Returns a serialized representation of this instruction.
   HloInstructionProto ToProto() const override;
@@ -972,12 +980,16 @@
   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
       HloCloneContext* context) const override;
-  Window window_;
-  // Describes the dimension numbers used for a convolution.
-  ConvolutionDimensionNumbers convolution_dimension_numbers_;
   // The number of feature groups. Must be a divisor of the input feature
   // dimension and output feature dimension.
   int64 feature_group_count_;
+  // Describes the window used for a convolution.
+  Window window_;
+  // Describes the dimension numbers used for a convolution.
+  ConvolutionDimensionNumbers convolution_dimension_numbers_;
+  // Information used to communicate to the implementation about the algorithm
+  // used to produce results. See the documentation on precision_config().
+  PrecisionConfig precision_config_;
 };
 
 class HloReduceWindowInstruction : public HloInstruction {
@@ -1270,6 +1282,85 @@
   const int64 iota_dimension_;
 };
 
+class HloDotInstruction : public HloInstruction {
+ public:
+  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+  // dimensions specified in 'dimension_numbers'.
+  explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+                             HloInstruction* rhs,
+                             const DotDimensionNumbers& dimension_numbers,
+                             const PrecisionConfig& precision_config);
+
+  // Returns data on the dimension numbers used for a dot operation.
+  const DotDimensionNumbers& dot_dimension_numbers() const {
+    return dot_dimension_numbers_;
+  }
+
+  // Returns the information used to tell the implementation information about
+  // what sort of precision is requested. The meaning of the field is backend
+  // specific. At the moment, it is only supported for kConvolution and kDot.
+  // Transformations on one kDot or kConvolution to another will preserve this
+  // information. Transformations to other HLOs will not preserve this
+  // information but it is presumed that the alternate lowering is strictly
+  // superior.
+  const PrecisionConfig& precision_config() const { return precision_config_; }
+
+  // Returns a serialized representation of this instruction.
+  HloInstructionProto ToProto() const override;
+
+ private:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+  // Implementation for non-common logic of CloneWithNewOperands.
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+      HloCloneContext* context) const override;
+  // Returns the dump string of the dot dimension numbers.
+  string DotDimensionNumbersToString() const;
+
+  // Describes the dimension numbers used for a dot.
+  DotDimensionNumbers dot_dimension_numbers_;
+
+  // Information used to communicate to the implementation about the algorithm
+  // used to produce results. See the documentation on precision_config().
+  PrecisionConfig precision_config_;
+};
+
+class HloDomainInstruction : public HloInstruction {
+ public:
+  explicit HloDomainInstruction(
+      const Shape& shape, HloInstruction* operand,
+      std::unique_ptr<DomainMetadata> operand_side_metadata,
+      std::unique_ptr<DomainMetadata> user_side_metadata);
+
+  // Retrieves the operand side metadata of a kDomain instruction.
+  const DomainMetadata& operand_side_metadata() const {
+    return *operand_side_metadata_;
+  }
+  // Retrieves the user side metadata of a kDomain instruction.
+  const DomainMetadata& user_side_metadata() const {
+    return *user_side_metadata_;
+  }
+
+ private:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+  // Implementation for non-common logic of CloneWithNewOperands.
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+      HloCloneContext* context) const override;
+
+  std::unique_ptr<DomainMetadata> operand_side_metadata_;
+  std::unique_ptr<DomainMetadata> user_side_metadata_;
+};
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
similarity index 71%
rename from tensorflow/compiler/xla/service/hlo_scheduling.cc
rename to tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 0fc3b26..c7ec88d 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 
 #include <map>
 #include <queue>
@@ -70,7 +70,7 @@
  public:
   // Construct and return a memory-minimizing sequence of HLO instructions
   // containing the given HLO computation.
-  static StatusOr<std::vector<const HloInstruction*>> Run(
+  static StatusOr<HloInstructionSequence> Run(
       const HloComputation& computation,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_function,
@@ -229,8 +229,8 @@
     return {BytesFreedIfScheduled(entry), entry.instruction->user_count()};
   }
 
-  std::vector<const HloInstruction*> CreateSchedule() {
-    std::vector<const HloInstruction*> schedule;
+  HloInstructionSequence CreateSchedule() {
+    HloInstructionSequence schedule;
 
     // Populate the ready list with instructions which have no operands or
     // control predecessors.
@@ -374,7 +374,7 @@
   return size;
 }
 
-StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
+StatusOr<HloInstructionSequence> ScheduleComputationHelper(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
@@ -392,7 +392,7 @@
 
 }  // namespace
 
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
@@ -443,7 +443,7 @@
   // Construct a total order based on DFS post-order, visiting operands in
   // decreasing cumulative extra user order, and next by cumulative size, with a
   // tiebreaker by name for determinism.
-  std::vector<const HloInstruction*> sequence;
+  HloInstructionSequence sequence;
   FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
     sequence.push_back(hlo);
     return Status::OK();
@@ -463,7 +463,7 @@
   return sequence;
 }  // namespace xla
 
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
@@ -473,18 +473,16 @@
                             memory_by_computation);
 }
 
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
     const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
         memory_by_computation) {
-  const auto& post_order = computation.MakeInstructionPostOrder();
-  return std::vector<const HloInstruction*>{post_order.begin(),
-                                            post_order.end()};
+  return HloInstructionSequence(computation.MakeInstructionPostOrder());
 }
 
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
     const HloComputation& computation,
     const TuplePointsToAnalysis& points_to_analysis,
     const LogicalBuffer::SizeFunction& size_function,
@@ -499,7 +497,7 @@
   // List wins for most of our benchmarks; postorder-based schedulers win for
   // some RNNs.
   TF_ASSIGN_OR_RETURN(
-      std::vector<const HloInstruction*> list_sequence,
+      HloInstructionSequence list_sequence,
       ListMemoryScheduler(computation, points_to_analysis, size_function,
                           memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 list_memory,
@@ -508,7 +506,7 @@
                           size_function, &memory_by_computation));
   VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
 
-  TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
+  TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence,
                       DFSMemoryScheduler(computation, points_to_analysis,
                                          size_function, memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
@@ -518,7 +516,7 @@
   VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
 
   TF_ASSIGN_OR_RETURN(
-      std::vector<const HloInstruction*> post_order_sequence,
+      HloInstructionSequence post_order_sequence,
       PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
                                memory_by_computation));
   TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
@@ -545,32 +543,35 @@
   }
 }
 
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+StatusOr<HloSchedule> ScheduleModule(
     const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
     const MemorySchedulerAlgorithm& algorithm) {
-  SequentialHloOrdering::HloModuleSequence sequence;
+  HloSchedule schedule(&module);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
                       TuplePointsToAnalysis::Run(&module));
   tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
   for (const auto* computation : module.MakeComputationPostOrder()) {
     if (!computation->IsFusionComputation()) {
-      TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
+      TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
                           ScheduleComputationHelper(
                               *computation, *points_to_analysis, size_function,
                               algorithm, memory_by_computation));
       memory_by_computation[computation] =
           HeapSimulator::MinimumMemoryForComputation(
-              *computation, one_computation_sequence, *points_to_analysis,
+              *computation, computation_sequence, *points_to_analysis,
               size_function, &memory_by_computation)
               .ValueOrDie();
-      sequence[computation] = std::move(one_computation_sequence);
+      schedule.set_sequence(computation, std::move(computation_sequence));
     }
   }
-  VLOG(1) << "Module schedule:\n" << sequence;
-  return sequence;
+  VLOG(1) << "Module schedule:\n" << schedule;
+
+  TF_RETURN_IF_ERROR(schedule.Verify());
+
+  return std::move(schedule);
 }
 
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
+StatusOr<HloInstructionSequence> ScheduleComputation(
     const HloComputation& computation,
     const LogicalBuffer::SizeFunction& size_function) {
   CHECK(!computation.IsFusionComputation());
@@ -581,187 +582,22 @@
                                    size_function, nullptr, empty_map);
 }
 
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) {
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence;
-  for (const auto& computation_sequence : sequence) {
-    for (const HloInstruction* instruction : computation_sequence.second) {
-      id_sequence[computation_sequence.first].push_back(
-          instruction->unique_id());
-    }
-  }
-  return id_sequence;
+HloMemoryScheduler::HloMemoryScheduler(
+    const LogicalBuffer::SizeFunction& size_function,
+    const MemorySchedulerAlgorithm& algorithm)
+    : size_function_(size_function), algorithm_(algorithm) {}
+
+StatusOr<bool> HloMemoryScheduler::Run(HloModule* module) {
+  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
+                      ScheduleModule(*module, size_function_, algorithm_));
+  TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+  return true;
 }
 
-Status UpdateSchedule(
-    const HloModule& module,
-    const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
-        id_sequence,
-    SequentialHloOrdering::HloModuleSequence* sequence) {
-  // Map from unique ID to HloInstruction pointer for instructions in the
-  // module.
-  tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
-  // Set of all HloInstructions in the schedule.
-  tensorflow::gtl::FlatSet<int> ids_in_schedule;
-  std::vector<HloComputation*> nonfusion_computations =
-      module.MakeNonfusionComputations();
-  for (const HloComputation* computation : nonfusion_computations) {
-    for (const HloInstruction* instruction : computation->instructions()) {
-      TF_RET_CHECK(
-          id_to_instruction.insert({instruction->unique_id(), instruction})
-              .second);
-    }
-    for (int id : id_sequence.at(computation)) {
-      ids_in_schedule.insert(id);
-    }
-  }
-
-  // Map from HloInstruction X to newly added instructions (instruction is in
-  // module, but not in schedule) which use X. If an instruction is not in the
-  // map, then it has no users which are newly added instructions.
-  tensorflow::gtl::FlatMap<const HloInstruction*,
-                           std::vector<const HloInstruction*>>
-      new_instruction_uses;
-
-  // For each newly added instruction, this is the count of the instruction's
-  // operands that have not yet been scheduled. When this value reaches zero,
-  // then the instruction may be placed in the schedule.
-  tensorflow::gtl::FlatMap<const HloInstruction*, int>
-      unscheduled_operand_count;
-  // For each computation, this is the set of newly added instructions which
-  // have no operands. These must be handled specially and are added to the
-  // beginning of the schedule.
-  tensorflow::gtl::FlatMap<const HloComputation*,
-                           std::vector<const HloInstruction*>>
-      new_zero_operand_instructions;
-  for (const HloComputation* computation : nonfusion_computations) {
-    new_zero_operand_instructions[computation] = {};
-    for (const HloInstruction* instruction : computation->instructions()) {
-      if (ids_in_schedule.count(instruction->unique_id()) == 0) {
-        // This is a newly added instruction which is not in the schedule.
-        for (const HloInstruction* operand : instruction->operands()) {
-          new_instruction_uses[operand].push_back(instruction);
-        }
-        if (instruction->operands().empty()) {
-          new_zero_operand_instructions[computation].push_back(instruction);
-        }
-        unscheduled_operand_count[instruction] = instruction->operand_count();
-      }
-    }
-  }
-
-  // Update the schedule with the newly added instructions, and remove any
-  // instructions no longer in the graph.
-  for (const HloComputation* computation : nonfusion_computations) {
-    std::vector<const HloInstruction*> old_computation_sequence =
-        std::move(sequence->at(computation));
-    sequence->at(computation).clear();
-
-    // Create a worklist of newly added instructions which are ready to be added
-    // to the schedule. Initialize worklist with those that have zero operands.
-    std::queue<const HloInstruction*> worklist;
-    for (const HloInstruction* instruction :
-         new_zero_operand_instructions.at(computation)) {
-      worklist.push(instruction);
-    }
-
-    // Lambda which schedules all instructions on the worklist.
-    auto schedule_worklist = [&]() {
-      while (!worklist.empty()) {
-        const HloInstruction* instruction = worklist.front();
-        worklist.pop();
-        sequence->at(computation).push_back(instruction);
-        std::vector<const HloInstruction*>* new_users =
-            tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
-        if (new_users != nullptr) {
-          // This just-scheduled instruction has users which are newly added to
-          // the module. Update the number of unscheduled operands and push the
-          // newly added instruction to the worklist if it is ready to
-          // schedule.
-          for (const HloInstruction* new_user : *new_users) {
-            unscheduled_operand_count.at(new_user)--;
-            CHECK_GE(unscheduled_operand_count.at(new_user), 0);
-            if (unscheduled_operand_count.at(new_user) == 0) {
-              worklist.push(new_user);
-            }
-          }
-        }
-      }
-    };
-
-    schedule_worklist();
-    for (int id : id_sequence.at(computation)) {
-      auto it = id_to_instruction.find(id);
-      if (it == id_to_instruction.end()) {
-        // This instruction in the schedule is no longer in the module.
-        continue;
-      }
-      const HloInstruction* instruction = it->second;
-      worklist.push(instruction);
-      schedule_worklist();
-    }
-  }
-
-  TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence));
-  return Status::OK();
-}
-
-Status VerifySchedule(
-    const HloModule& module,
-    const SequentialHloOrdering::HloModuleSequence& sequence) {
-  VLOG(2) << "VerifySchedule()";
-  XLA_VLOG_LINES(2, module.ToString());
-  VLOG(2) << sequence;
-
-  // Verify the set of computations in the sequence is exactly the set of
-  // computations in the module.
-  std::vector<HloComputation*> nonfusion_computations =
-      module.MakeNonfusionComputations();
-  TF_RET_CHECK(nonfusion_computations.size() == sequence.size());
-  tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module(
-      module.computations().begin(), module.computations().end());
-  for (const auto& computation_sequence : sequence) {
-    TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1);
-  }
-
-  // For each computation verify the set of instructions is the same and that
-  // each dependency and control edge is honored.
-  for (const HloComputation* computation : nonfusion_computations) {
-    tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
-    int pos = 0;
-    for (const HloInstruction* instruction : sequence.at(computation)) {
-      TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
-          << "Instruction " << instruction->name()
-          << " appears more than once in the schedule";
-      pos++;
-    }
-
-    TF_RET_CHECK(instruction_position.size() ==
-                 computation->instruction_count());
-    for (const HloInstruction* instruction : computation->instructions()) {
-      TF_RET_CHECK(instruction_position.count(instruction) == 1)
-          << "Instruction " << instruction->name() << " is not in schedule";
-    }
-
-    for (const HloInstruction* instruction : computation->instructions()) {
-      for (const HloInstruction* operand : instruction->operands()) {
-        TF_RET_CHECK(instruction_position.at(operand) <
-                     instruction_position.at(instruction))
-            << "Instruction " << instruction->name()
-            << " is not scheduled after its operand " << operand->name();
-      }
-
-      for (const HloInstruction* pred : instruction->control_predecessors()) {
-        TF_RET_CHECK(instruction_position.at(pred) <
-                     instruction_position.at(instruction))
-            << "Instruction " << instruction->name()
-            << " is not scheduled after its control predecessor "
-            << pred->name();
-      }
-    }
-  }
-
-  return Status::OK();
+StatusOr<bool> HloDescheduler::Run(HloModule* module) {
+  bool changed = module->has_schedule();
+  module->clear_schedule();
+  return changed;
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
new file mode 100644
index 0000000..5e02868
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -0,0 +1,123 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// A memory scheduler computes an execution sequence for the HLO instructions in
+// 'computation' that minimizes peak memory, given a points-to analysis result
+// that describes buffer aliasing, together with a target-specific size function
+// that maps a tensor's logical size to its padded size.
+typedef std::function<StatusOr<HloInstructionSequence>(
+    const HloComputation&, const TuplePointsToAnalysis&,
+    const LogicalBuffer::SizeFunction&,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
+    MemorySchedulerAlgorithm;
+
+// List scheduler
+StatusOr<HloInstructionSequence> ListMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation);
+
+// DFS-order scheduler
+StatusOr<HloInstructionSequence> DFSMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation);
+
+// Naive Post Order scheduler
+StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation);
+
+// The default scheduling algorithm. Runs both the list scheduler
+// and the DFS scheduler, and chooses whichever returns a lower min-memory,
+// not accounting for fragmentation.
+StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
+    const HloComputation& computation,
+    const TuplePointsToAnalysis& points_to_analysis,
+    const LogicalBuffer::SizeFunction& size_function,
+    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+        memory_by_computation);
+
+// Returns an HloSchedule which seeks to minimize the memory required for
+// the computation. size_function is the function returning the number of bytes
+// required for a LogicalBuffer.
+StatusOr<HloSchedule> ScheduleModule(
+    const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
+    const MemorySchedulerAlgorithm& algorithm = {});
+
+// Computes the schedule for a single computation.
+// Currently only used by the GPU backend.
+StatusOr<HloInstructionSequence> ScheduleComputation(
+    const HloComputation& computation,
+    const LogicalBuffer::SizeFunction& size_function);
+
+// A pass which schedules the HLO instructions in a module. The HloModule's
+// schedule field is set to the resulting HloSchedule using
+// HloModule::set_schedule.
+class HloMemoryScheduler : public HloPassInterface {
+ public:
+  // size_function is the function returning the number of bytes required for a
+  // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
+  // specified, then DefaultMemoryScheduler is used.
+  HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function,
+                     const MemorySchedulerAlgorithm& algorithm = {});
+  ~HloMemoryScheduler() override = default;
+  absl::string_view name() const override { return "hlo-memory-scheduler"; }
+
+  StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  LogicalBuffer::SizeFunction size_function_;
+  MemorySchedulerAlgorithm algorithm_;
+};
+
+// A trivial pass which clears the schedule currently set on the
+// HloModule. After this pass runs HloModudle::has_schedule will return false.
+class HloDescheduler : public HloPassInterface {
+ public:
+  HloDescheduler() = default;
+  ~HloDescheduler() override = default;
+  absl::string_view name() const override { return "hlo-descheduler"; }
+
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
new file mode 100644
index 0000000..1b9e9bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -0,0 +1,432 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloSchedulingTest : public HloTestBase {};
+
+TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
+  // Tests scheduling of the following HLO code:
+  //
+  //   %ab = abs(%param)
+  //   %exp = exp(%param)
+  //   %add = add(%ab, %exp)
+  //   %negate = negate(%exp)
+  //   %sub = subtract(%add, %negate)
+  //
+  // %add should be scheduled before %negate because %add is the last (and only)
+  // use of %ab. Scheduling %add first then frees up %ab's buffer.
+  const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
+  auto builder = HloComputation::Builder(TestName());
+  auto param =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
+  auto ab = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));
+
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
+  auto negate = builder.AddInstruction(
+      HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
+  auto sub = builder.AddInstruction(
+      HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));
+
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+
+  HloMemoryScheduler scheduler([](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  });
+  ASSERT_FALSE(module->has_schedule());
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
+  EXPECT_TRUE(changed);
+  ASSERT_TRUE(module->has_schedule());
+  TF_ASSERT_OK(module->schedule().Verify());
+
+  // Verify that all instructions are in the sequence.
+  const std::vector<const HloInstruction*>& sequence =
+      module->schedule().sequence(module->entry_computation()).instructions();
+  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
+
+  // The first instruction should be the parameter and the last the root "sub".
+  EXPECT_EQ(param, sequence.front());
+  EXPECT_EQ(sub, sequence.back());
+
+  SequentialHloOrdering ordering(module->schedule());
+  EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
+
+  // Clear the schedule using the descheduling pass.
+  HloDescheduler descheduler;
+  EXPECT_TRUE(module->has_schedule());
+  TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
+                          descheduler.Run(module.get()));
+  EXPECT_TRUE(descheduler_changed);
+  EXPECT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
+  const char* module_str = R"(
+HloModule test_aliasing_module
+
+ENTRY root {
+  param = s32[1000] parameter(0)
+  p0 = s32[1000] copy(param)
+  p1 = s32[1000] copy(param)
+  t = (s32[1000], s32[1000]) tuple(p0, p1)
+  a = s32[1000] get-tuple-element(t), index=0
+  b = s32[1000] get-tuple-element(t), index=1
+  c = s32[1000] add(a, b)
+  d = s32[1000] add(c, b)
+  e = s32[1000] add(c, c)
+  f = s32[1000] add(e, e)
+  ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
+  };
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, size_fn, ListMemoryScheduler));
+  // Verify that all instructions are in the sequence.
+  const std::vector<const HloInstruction*>& sequence =
+      schedule.sequence(module->entry_computation()).instructions();
+  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());
+
+  std::unordered_map<string, const HloInstruction*> instructions_by_name;
+  for (const HloInstruction* instruction : sequence) {
+    instructions_by_name[instruction->name()] = instruction;
+  }
+
+  // The first instruction should be the parameter and the last the root.
+  EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
+  EXPECT_EQ(instructions_by_name.at("result"), sequence.back());
+
+  // Instructions "d" and "e" will both be schedulable at the same time, but
+  // instruction "d" allows us to free the buffer of "p1", so the list scheduler
+  // should prefer it.
+  SequentialHloOrdering ordering(schedule);
+  EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
+                                      instructions_by_name.at("e")));
+}
+
+TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
+  // %WhileCond (cond_param: f32[4]) -> pred[] {
+  //   %cond_param = f32[4]{0} parameter(0)
+  //   %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } })
+  //   ROOT %not-equal-to = pred[] not-equal-to(
+  //     f32[4]{0} %cond_param, f32[1,4]{1,0} %constant)
+  // }
+  // %WhileBody (body_param: f32[4]) -> f32[4] {
+  //   %body_param = f32[4]{0} parameter(0)
+  //   %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+  //   ROOT %subtract = f32[4]{0} subtract(
+  //     f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
+  // }
+  // %ListAccountsForSubcomputations () -> f32[2,4] {
+  //   %constant.3 = f32[2,4]{1,0} constant(
+  //     f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
+  //   %transpose = f32[2,4]{1,0} transpose(
+  //     f32[2,4]{1,0} %constant.3), dimensions={0,1}
+  //   %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
+  //   %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2),
+  //      condition=%WhileCond,
+  //      body=%WhileBody
+  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
+  //   ROOT %add = f32[2,4]{1,0} add(
+  //     f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
+  // }
+
+  auto module = CreateNewModule();
+  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+  // param != 0
+  // Needs 17 bytes
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+  HloInstruction* zero_vector =
+      cond_builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
+  cond_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+  // param - 1
+  // Needs 16 bytes
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "body_param"));
+  HloInstruction* one_vector =
+      body_builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+  body_builder.AddInstruction(HloInstruction::CreateBinary(
+      r1f32, HloOpcode::kSubtract, body_param, one_vector));
+  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+  // transpose(matrix) + bcast(while)
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* while_init =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+  // Creates 16 bytes, ignoring subcomputations
+  HloInstruction* while_loop =
+      builder.AddInstruction(HloInstruction::CreateWhile(
+          r1f32, cond_computation, body_computation, while_init));
+
+  // Creates 32 bytes and frees 16
+  HloInstruction* bcast = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
+
+  HloInstruction* matrix = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
+          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
+  // Creates 32 bytes
+  HloInstruction* transpose = builder.AddInstruction(
+      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
+
+  // Creates 32 bytes and frees 64
+  HloInstruction* add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
+
+  module->AddEntryComputation(builder.Build());
+
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  };
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, size_fn, ListMemoryScheduler));
+  // Verify that all instructions are in the sequence.
+  auto entry_computation = module->entry_computation();
+  EXPECT_EQ(entry_computation->instruction_count(),
+            schedule.sequence(entry_computation).size());
+  SequentialHloOrdering ordering(schedule);
+  // This schedule is an example of List's greedy heuristics being suboptimal.
+  // The while_loop is more expensive than transpose, so it would have been
+  // better to schedule it first, instead of during the busy time.
+  EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop));
+  EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
+  EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
+  EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+  memory_by_computation[cond_computation] = 17;
+  memory_by_computation[body_computation] = 16;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+  // HeapSimulator doesn't account for subcomputations
+  EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, schedule.sequence(entry_computation),
+                    *points_to_analysis, size_fn)
+                    .ValueOrDie());
+  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
+  // so we don't double count.
+  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, schedule.sequence(entry_computation),
+                    *points_to_analysis, size_fn, &memory_by_computation)
+                    .ValueOrDie());
+}
+
+TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
+  auto builder = HloComputation::Builder(TestName());
+  const auto TUPLE_SIZE = 1;
+  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});
+
+  // Wrap lit in abs because constants are considered free by
+  // IgnoreInstruction, and it skews the accounting.
+  auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
+  auto abs_const = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
+
+  auto abs_abs1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
+      absl::Span<HloInstruction* const>({abs_abs1})));
+  auto tuple_elm = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+  auto abs_abs2 = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
+
+  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
+                                                      tuple_elm, abs_abs2));
+
+  auto module = CreateNewModule();
+  module->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+                          ScheduleModule(*module,
+                                         [](const BufferValue& buffer) {
+                                           return ShapeUtil::ByteSizeOf(
+                                               buffer.shape(), TUPLE_SIZE);
+                                         },
+                                         ListMemoryScheduler));
+
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            schedule.sequence(module->entry_computation()).size());
+  SequentialHloOrdering ordering(schedule);
+  // tuple allocates the tuple buffer and doesn't free anything.
+  // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
+  // abs_abs2 should be scheduled before tuple by List.
+  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
+}
+
+TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
+  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
+  HloComputation::Builder builder(TestName());
+
+  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
+  auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
+  auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));
+
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
+  auto mul = builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
+  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));
+
+  auto tuple_elm = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
+
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));
+
+  builder.AddInstruction(
+      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+
+  auto fusion = computation->CreateFusionInstruction(
+      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
+
+  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
+                          ScheduleModule(*module,
+                                         [](const BufferValue& buffer) {
+                                           return ShapeUtil::ByteSizeOf(
+                                               buffer.shape(), 2);
+                                         },
+                                         ListMemoryScheduler));
+
+  // Verify that all instructions are in the sequence.
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            schedule.sequence(module->entry_computation()).size());
+  SequentialHloOrdering ordering(schedule);
+  // fusion allocates memory for the tuple elements and doesn't free anything,
+  // so it's more expensive than exp.
+  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
+}
+
+TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
+  auto module = CreateNewModule();
+  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+
+  // param != 0
+  // Needs 17 bytes
+  auto cond_builder = HloComputation::Builder("WhileCond");
+  HloInstruction* cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+  HloInstruction* zero_vector =
+      cond_builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
+  cond_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+  // param - 1
+  // Needs 16 bytes
+  auto body_builder = HloComputation::Builder("WhileBody");
+  HloInstruction* body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, r1f32, "body_param"));
+  HloInstruction* one_vector =
+      body_builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+  body_builder.AddInstruction(HloInstruction::CreateBinary(
+      r1f32, HloOpcode::kSubtract, body_param, one_vector));
+  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* while_init =
+      builder.AddInstruction(HloInstruction::CreateConstant(
+          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+  // Creates 16 bytes, ignoring subcomputations
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      r1f32, cond_computation, body_computation, while_init));
+
+  module->AddEntryComputation(builder.Build());
+
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  };
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, size_fn, ListMemoryScheduler));
+  // Verify that all instructions are in the sequence.
+  auto entry_computation = module->entry_computation();
+  EXPECT_EQ(module->entry_computation()->instruction_count(),
+            schedule.sequence(module->entry_computation()).size());
+
+  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+  memory_by_computation[cond_computation] = 17;
+  memory_by_computation[body_computation] = 16;
+  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+  // HeapSimulator doesn't account for subcomputations
+  EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, schedule.sequence(entry_computation),
+                    *points_to_analysis, size_fn)
+                    .ValueOrDie());
+  // HeapSimulator accounts for subcomputations. Cond is the largest one.
+  // The output buffer of the while is aliased.
+  EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
+                    *entry_computation, schedule.sequence(entry_computation),
+                    *points_to_analysis, size_fn, &memory_by_computation)
+                    .ValueOrDie());
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 3a1bc4e..b3949f3 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -26,6 +26,7 @@
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -50,9 +51,16 @@
   return const_cast<HloInstruction*>(hlo);
 }
 
+Status HloModule::set_schedule(HloSchedule schedule) {
+  TF_RET_CHECK(schedule.module() == this);
+  TF_RETURN_IF_ERROR(schedule.Verify());
+  schedule_ = std::move(schedule);
+  return Status::OK();
+}
+
 HloComputation* HloModule::AddComputationInternal(
     std::unique_ptr<HloComputation> computation, bool is_entry,
-    bool uniquify_names) {
+    bool uniquify_identifiers) {
   if (is_entry) {
     CHECK_EQ(nullptr, entry_computation_);
     entry_computation_ = computation.get();
@@ -65,30 +73,36 @@
     }
   }
 
-  if (uniquify_names) {
+  if (uniquify_identifiers) {
     computation->UniquifyName(&computation_name_uniquer_);
     for (auto* instruction : computation->instructions()) {
       instruction->UniquifyName(&instruction_name_uniquer_);
     }
+
+    // Pick unique IDs for each instruction.
+    for (auto* instruction : computation->instructions()) {
+      instruction->SetUniqueId(NewUniqueInstructionId());
+    }
+    // Set unique id to this computation.
+    CHECK_NE(computation->root_instruction()->unique_id(), -1)
+        << "Root has no valid id: " << computation->ToString();
+    computation->SetUniqueId(computation->root_instruction()->unique_id());
   } else {
     // Don't uniquify the names of the computation or instruction, but we must
     // run the names through the uniquifiers to prevent future name collisions
-    // for computations and instructions created later.
+    // for computations and instructions created later. Also, set the
+    // next_unique_id_ to the one greater than the max unique id of any
+    // instruction (or the computation) to avoid ID collisions.
     computation_name_uniquer_.GetUniqueName(computation->name());
     for (auto* instruction : computation->instructions()) {
       instruction_name_uniquer_.GetUniqueName(instruction->name());
+      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+    }
+    if (next_unique_id_ < computation->unique_id() + 1) {
+      next_unique_id_ = computation->unique_id() + 1;
     }
   }
 
-  // Pick unique IDs for each instruction.
-  for (auto* instruction : computation->instructions()) {
-    instruction->SetUniqueId(NewUniqueInstructionId());
-  }
-  // Set unique id to this computation.
-  CHECK_NE(computation->root_instruction()->unique_id(), -1)
-      << "Root has no valid id: " << computation->ToString();
-  computation->SetUniqueId(computation->root_instruction()->unique_id());
-
   computation->set_parent(this);
   computations_.push_back(std::move(computation));
   return computations_.back().get();
@@ -97,7 +111,7 @@
 HloComputation* HloModule::AddEntryComputation(
     std::unique_ptr<HloComputation> computation) {
   return AddComputationInternal(std::move(computation), /*is_entry=*/true,
-                                /*uniquify_names=*/true);
+                                /*uniquify_identifiers=*/true);
 }
 
 Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -114,7 +128,7 @@
 HloComputation* HloModule::AddEmbeddedComputation(
     std::unique_ptr<HloComputation> computation) {
   return AddComputationInternal(std::move(computation), /*is_entry=*/false,
-                                /*uniquify_names=*/true);
+                                /*uniquify_identifiers=*/true);
 }
 
 void HloModule::ReplaceComputations(
@@ -198,12 +212,23 @@
 
 string HloModule::ToString(const HloPrintOptions& options) const {
   std::ostringstream s;
-  s << "HloModule " << name() << "\n\n";
+  s << "HloModule " << name();
+  if (has_schedule()) {
+    TF_CHECK_OK(schedule().Verify());
+    s << ", is_scheduled=true";
+  }
+  s << "\n\n";
   for (const HloComputation* computation : MakeComputationPostOrder()) {
     if (computation == entry_computation()) {
       s << "ENTRY ";
     }
-    s << computation->ToString(options) << "\n\n";
+    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+      s << computation->ToString(
+               options, schedule().sequence(computation).instructions())
+        << "\n\n";
+    } else {
+      s << computation->ToString(options) << "\n\n";
+    }
   }
   return s.str();
 }
@@ -221,12 +246,18 @@
     }
     proto.add_computations()->Swap(&computation_proto);
   }
+  if (has_schedule()) {
+    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+  }
   return proto;
 }
 
 /* static */
 StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
     const HloModuleProto& proto, const HloModuleConfig& module_config) {
+  VLOG(2) << "CreateFromProto()";
+  XLA_VLOG_LINES(2, proto.DebugString());
+
   // The ProgramShape in the passed in module config must match the shapes of
   // the entry parameters and root.
   TF_RET_CHECK(proto.has_program_shape())
@@ -290,25 +321,42 @@
     // Don't uniquify names because we want names to be stable across
     // serialization and deserialization.
     module->AddComputationInternal(std::move(computation), is_entry,
-                                   /*uniquify_names=*/false);
+                                   /*uniquify_identifiers=*/false);
   }
   TF_RET_CHECK(module->entry_computation_ != nullptr);
 
-  // Because we didn't uniquify the names, double-check that the instruction and
-  // computation names are unique from the proto.
+  // Because we didn't uniquify the names or the ids, double-check that the
+  // instruction and computation names and ids are unique from the proto.
   tensorflow::gtl::FlatSet<string> computation_names;
   tensorflow::gtl::FlatSet<string> instruction_names;
+  tensorflow::gtl::FlatSet<int> computation_ids;
+  tensorflow::gtl::FlatSet<int> instruction_ids;
   for (HloComputation* computation : module->computations()) {
     TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
         << "Computation name is not unique: " << computation->name();
     computation_names.insert(computation->name());
+
+    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+        << "Computation id is not unique: " << computation->unique_id();
+    computation_ids.insert(computation->unique_id());
     for (HloInstruction* instruction : computation->instructions()) {
       TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
           << "Instruction name is not unique: " << instruction->name();
       instruction_names.insert(instruction->name());
+
+      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+          << "Instruction id is not unique: " << instruction->unique_id();
+      instruction_ids.insert(instruction->unique_id());
     }
   }
 
+  if (proto.has_schedule()) {
+    TF_ASSIGN_OR_RETURN(
+        HloSchedule schedule,
+        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+  }
+
   return std::move(module);
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3c33714..3bc2d13 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/iterator_util.h"
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,6 +33,7 @@
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -235,10 +237,23 @@
   StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
       const HloInstruction* hlo);
 
+  // Sets the schedule of the module to the given schedule.
+  Status set_schedule(HloSchedule schedule);
+
+  // Clears the schedule of the module.
+  void clear_schedule() { schedule_.reset(); }
+
+  // Returns true if the module has a schedule set.
+  bool has_schedule() const { return schedule_.has_value(); }
+
+  // Returns the schedule of the module. CHECK fails if no schedule is set.
+  const HloSchedule& schedule() const { return *schedule_; }
+  HloSchedule& schedule() { return *schedule_; }
+
  private:
   HloComputation* AddComputationInternal(
       std::unique_ptr<HloComputation> computation, bool is_entry,
-      bool uniquify_names);
+      bool uniquify_identifiers);
 
   const string name_;
   HloModuleConfig config_;
@@ -262,6 +277,11 @@
   static std::atomic<int> next_unique_module_id_;
   // A unique id to label modules with.
   int unique_id_;
+
+  // The HloSchedule of the module. The schedule, if it exists, contains a
+  // sequential order of instructions for each non-fusion computation in the
+  // module.
+  absl::optional<HloSchedule> schedule_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 3f1e1cc..68c1883 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -106,9 +106,6 @@
 
   absl::optional<ComputationLayout> entry_computation_layout_;
 
-  // Whether this is a 'host module'.
-  bool is_host_module_ = false;
-
   // Module/graph-level seed handle.
   uint64 seed_ = 0;
 
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index 98d2031..f7be5ca 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -36,23 +36,6 @@
 
 namespace {
 
-bool HasSendRecv(HloComputation* computation) {
-  for (auto* instruction : computation->instructions()) {
-    if (instruction->opcode() == HloOpcode::kSend ||
-        instruction->opcode() == HloOpcode::kSendDone ||
-        instruction->opcode() == HloOpcode::kRecv ||
-        instruction->opcode() == HloOpcode::kRecvDone) {
-      return true;
-    }
-    for (auto* sub_computation : instruction->called_computations()) {
-      if (HasSendRecv(sub_computation)) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
 StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
   bool changed = false;
   for (auto* computation : module->computations()) {
@@ -68,9 +51,10 @@
 
       if (!ShapeUtil::IsTuple(xla_while->shape()) ||
           while_body_root->opcode() != HloOpcode::kTuple ||
-          HasSendRecv(while_body_comp)) {
+          while_body_comp->HasSideEffect() ||
+          xla_while->while_condition()->HasSideEffect()) {
         // Only run DCE on tuple-shaped while loops where body root is Tuple,
-        // with no send/recv instructions.
+        // with no I/O instructions.
         VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
         continue;
       }
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e..bf66cc6 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,77 @@
                                                   "while.2", 1));
 }
 
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+  auto module = ParseHloString(R"(
+  HloModule OutfeedLoop
+  WhileBody {
+    body_param = (s32[]) parameter(0)
+    token = token[] after-all()
+    constant.2 = s32[] constant(2)
+    outfeed_tuple = (s32[]) outfeed(constant.2, token)
+    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    ROOT tuple = (s32[]) tuple(add)
+  }
+  WhileCondition {
+    cond_param = (s32[]) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+    constant.2 = s32[] constant(10)
+    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+  }
+  ENTRY SimpleLoop {
+    constant.3 = s32[] constant(0)
+    tuple.1 = (s32[]) tuple(constant.3)
+    while = (s32[]) while(tuple.1), condition=WhileCondition,
+      body=WhileBody
+    ROOT rtuple = () tuple()
+  })")
+                    .ValueOrDie();
+
+  HloModuleDCE dce;
+  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+                                                   "while", 0));
+}
+
+// Tests that if a loop variable is not referenced outside of a kWhile, the loop
+// variable changes are not elided within the loop body, if the condition
+// computation uses them.
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
+  auto module = ParseHloString(R"(
+  HloModule InfiniteLoop
+  WhileBody {
+    body_param = (s32[], s32[]) parameter(0)
+    get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+    get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
+    constant.1 = s32[] constant(1)
+    add = s32[] add(get-tuple-element.1, constant.1)
+    ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
+  }
+  WhileCondition {
+    cond_param = (s32[], s32[]) parameter(0)
+    get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+    constant.2 = s32[] constant(10)
+    ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+  }
+  ENTRY SimpleLoop {
+    p0 = (s32[]) parameter(0)
+    get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
+    constant.3 = s32[] constant(0)
+    tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
+    while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
+      body=WhileBody
+    ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
+  })")
+                    .ValueOrDie();
+
+  HloModuleDCE dce;
+  EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+                                                   "while", 0));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
new file mode 100644
index 0000000..f9b56ef
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+namespace xla {
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+                               std::unique_ptr<HloModule> module)
+    : name_(name) {
+  push_back(std::move(module));
+}
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+                               absl::Span<std::unique_ptr<HloModule>> modules)
+    : name_(name) {
+  for (auto& module : modules) {
+    push_back(std::move(module));
+  }
+}
+
+std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() {
+  std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_);
+
+  // Clear everything so the object state is in a known (empty) state.
+  modules_.clear();
+  module_ptrs_.clear();
+  return ret_modules;
+}
+
+string HloModuleGroup::ToString() const {
+  std::ostringstream s;
+  s << "HloModuleGroup " << name() << "\n\n";
+  for (const HloModule* module : modules()) {
+    s << module->ToString() << "\n";
+  }
+  return s.str();
+}
+
+HloModuleGroupProto HloModuleGroup::ToProto() const {
+  HloModuleGroupProto proto;
+  proto.set_name(name());
+  for (const HloModule* module : modules()) {
+    *proto.add_hlo_modules() = module->ToProto();
+  }
+  return proto;
+}
+
+/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto(
+    const HloModuleGroupProto& proto,
+    absl::Span<const HloModuleConfig> module_configs) {
+  TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty";
+  TF_RET_CHECK(proto.hlo_modules_size() > 0)
+      << "Module group must have at least one HLO module";
+  TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size());
+
+  std::vector<std::unique_ptr<HloModule>> modules;
+  for (int i = 0; i < proto.hlo_modules_size(); ++i) {
+    const HloModuleProto& module_proto = proto.hlo_modules(i);
+    TF_ASSIGN_OR_RETURN(
+        std::unique_ptr<HloModule> module,
+        HloModule::CreateFromProto(module_proto, module_configs[i]));
+    modules.push_back(std::move(module));
+  }
+
+  return HloModuleGroup(proto.name(), absl::MakeSpan(modules));
+}
+
+void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
+  modules_.push_back(std::move(module));
+  module_ptrs_.push_back(modules_.back().get());
+}
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) {
+  out << group.ToString();
+  return out;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
new file mode 100644
index 0000000..7338be8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// An abstraction representing an ordered set of HLO modules built to run
+// concurrently across different devices.
+class HloModuleGroup {
+ public:
+  // Construct an empty module group.
+  explicit HloModuleGroup(absl::string_view name) : name_(name) {}
+
+  // Construct a module group containing a single module.
+  HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+
+  // Construct a module group containing any number of modules.
+  HloModuleGroup(absl::string_view name,
+                 absl::Span<std::unique_ptr<HloModule>> modules);
+
+  // Returns the modules contained in the group.
+  const std::vector<HloModule*>& modules() const { return module_ptrs_; }
+
+  // Returns a module at a particular index.
+  HloModule& module(int index) const { return *module_ptrs_.at(index); }
+
+  // Add a module to the back of vector of modules in the group.
+  void push_back(std::unique_ptr<HloModule> module);
+
+  // Moves all modules from the group into the returned vector. After this
+  // method runs, the module group will be empty.
+  std::vector<std::unique_ptr<HloModule>> ConsumeModules();
+
+  string name() const { return name_; }
+  string ToString() const;
+
+  // Serialize the module group to/from a proto.
+  HloModuleGroupProto ToProto() const;
+  static StatusOr<HloModuleGroup> CreateFromProto(
+      const HloModuleGroupProto& proto,
+      absl::Span<const HloModuleConfig> module_configs);
+
+ private:
+  string name_;
+
+  // Vector of modules as std::unique_ptrs.
+  std::vector<std::unique_ptr<HloModule>> modules_;
+
+  // Vector of modules as normal pointers. This vector is kept in sync with
+  // modules_ as modules are added to the group with push_back.
+  std::vector<HloModule*> module_ptrs_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
new file mode 100644
index 0000000..ebf790b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class HloModuleGroupTest : public HloTestBase {
+ protected:
+  HloModuleGroupTest() = default;
+};
+
+TEST_F(HloModuleGroupTest, SingleModule) {
+  const string text = R"(
+HloModule simple_module
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  HloModuleGroup group(TestName(), std::move(module));
+
+  EXPECT_EQ(group.modules().size(), 1);
+  EXPECT_THAT(
+      group.module(0).entry_computation()->instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+                          HloModuleGroup::CreateFromProto(
+                              group.ToProto(), {group.module(0).config()}));
+  EXPECT_EQ(group_copy.modules().size(), 1);
+  EXPECT_THAT(
+      group_copy.module(0).entry_computation()->instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+  std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
+  EXPECT_EQ(modules.size(), 1);
+  EXPECT_EQ(group.modules().size(), 0);
+}
+
+TEST_F(HloModuleGroupTest, MultipleModules) {
+  const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+)";
+  const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+  ROOT %a = f32[] parameter(0)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+                          ParseHloString(text_0));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+                          ParseHloString(text_1));
+  std::vector<std::unique_ptr<HloModule>> modules;
+  modules.push_back(std::move(module_0));
+  modules.push_back(std::move(module_1));
+  HloModuleGroup group(TestName(), absl::MakeSpan(modules));
+  EXPECT_EQ(group.modules().size(), 2);
+  EXPECT_THAT(
+      group.module(0).entry_computation()->instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+              ::testing::ElementsAre(op::Parameter()));
+
+  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+                          HloModuleGroup::CreateFromProto(
+                              group.ToProto(), {group.module(0).config(),
+                                                group.module(1).config()}));
+  EXPECT_EQ(group_copy.modules().size(), 2);
+}
+
+TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
+  const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(%x, %y)
+}
+)";
+  const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+  ROOT %a = f32[] parameter(0)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+                          ParseHloString(text_0));
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+                          ParseHloString(text_1));
+  HloModuleGroup group(TestName());
+  group.push_back(std::move(module_0));
+  group.push_back(std::move(module_1));
+
+  EXPECT_EQ(group.modules().size(), 2);
+  EXPECT_THAT(
+      group.module(0).entry_computation()->instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+  EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+              ::testing::ElementsAre(op::Parameter()));
+}
+
+}  // namespace
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4bc1bac..39f38b4 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -19,10 +19,13 @@
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/test.h"
 
@@ -30,6 +33,8 @@
 
 namespace {
 
+namespace op = ::xla::testing::opcode_matchers;
+
 class HloModuleTest : public HloTestBase {
  protected:
   HloModuleTest() {}
@@ -194,6 +199,153 @@
   EXPECT_NE(module_a->unique_id(), module_b->unique_id());
 }
 
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+  const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %x = f32[2,4]{1,0} parameter(1)
+  %y = f32[2,4]{1,0} parameter(2)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_FALSE(module->has_schedule());
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module_copy,
+      HloModule::CreateFromProto(module->ToProto(), module->config()));
+  ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+  const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %x = f32[2,4]{1,0} parameter(1)
+  %y = f32[2,4]{1,0} parameter(2)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_TRUE(module->has_schedule());
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module_copy,
+      HloModule::CreateFromProto(module->ToProto(), module->config()));
+  ASSERT_TRUE(module_copy->has_schedule());
+  TF_ASSERT_OK(module_copy->schedule().Verify());
+  EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+  ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+      module_copy->entry_computation()));
+  EXPECT_THAT(
+      module_copy->schedule()
+          .sequence(module_copy->entry_computation())
+          .instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+                             op::Broadcast(), op::Multiply(), op::Add()));
+}
+
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+  // Verify that serializing then deserializing an HLO proto preserves the
+  // unique IDs of the instruction and module.
+  const string text =
+      R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+  input = f32[8,16,256]{2,1,0} parameter(0)
+  constant = f32[] constant(0)
+  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+
+  // Perform various transformations on the graph:
+  //
+  //  * clone the reduction function
+  //  * replace use of reduction function with the clone.
+  //  * add a random instruction to the entry computation.
+  //
+  // This will create instruction and computation IDs which are interesting:
+  // not consecutive and not densely packed.
+  HloComputation* entry = module->entry_computation();
+  HloInstruction* root = entry->root_instruction();
+  HloComputation* reduction = root->to_apply();
+  HloComputation* reduction_clone =
+      module->AddEmbeddedComputation(reduction->Clone());
+  root->set_to_apply(reduction_clone);
+  TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+  HloInstruction* negate = entry->AddInstruction(
+      HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+  entry->set_root_instruction(negate);
+
+  // Schedule the transformed module, this verifies that the serialized schedule
+  // is robust against non-consecutive IDs as well (b/114712358).
+  auto size_fn = [](const BufferValue& buffer) {
+    return ShapeUtil::ByteSizeOf(buffer.shape());
+  };
+  HloMemoryScheduler scheduler(size_fn);
+  TF_ASSERT_OK(scheduler.Run(module.get()).status());
+  ASSERT_TRUE(module->has_schedule());
+
+  // Serialize and deserialize and verify that the instruction and computations
+  // unique ids are the same.
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module_copy,
+      HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+  // The module IDs should *not* be the same because module ids must be globally
+  // unique.
+  EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+  // Verify that the computations and instructions all have the same unique id.
+  auto computation_copy_it = module_copy->computations().begin();
+  for (const HloComputation* computation_orig : module->computations()) {
+    const HloComputation* computation_copy = *computation_copy_it++;
+    EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+        << absl::StrFormat(
+               "ID of original computation %s != ID of deserialized "
+               "computation %s: %d != %d",
+               computation_orig->name(), computation_copy->name(),
+               computation_orig->unique_id(), computation_copy->unique_id());
+
+    auto instruction_copy_it = computation_copy->instructions().begin();
+    for (const HloInstruction* instruction_orig :
+         computation_orig->instructions()) {
+      const HloInstruction* instruction_copy = *instruction_copy_it++;
+      EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+          << absl::StrFormat(
+                 "ID of original instruction %s != ID of deserialized "
+                 "instruction %s: %d != %d",
+                 instruction_orig->name(), instruction_copy->name(),
+                 instruction_orig->unique_id(), instruction_copy->unique_id());
+    }
+  }
+
+  // Verify that the next unique ID which the module would have handed out is
+  // greater than the unique id of any instruction.
+  int next_id = module_copy->NewUniqueInstructionId();
+  for (const HloComputation* computation : module_copy->computations()) {
+    for (const HloInstruction* instruction : computation->instructions()) {
+      EXPECT_GT(next_id, instruction->unique_id());
+    }
+  }
+}
+
 }  // namespace
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 0581d5c..f1dc08b 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -252,6 +253,12 @@
     VLOG(4) << a << " not defined before " << b;
     return false;
   }
+
+  if (a.live_out_of_module()) {
+    VLOG(4) << a << " is live out of module and defined before " << b;
+    return false;
+  }
+
   // All uses of 'a' must be before 'b' is defined.
   for (const HloUse& use : a.uses()) {
     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
@@ -264,6 +271,18 @@
       return false;
     }
   }
+
+  if (a.instruction()->parent() == b.instruction()->parent()) {
+    for (const HloPosition& position : a.positions()) {
+      if (position.instruction ==
+          a.instruction()->parent()->root_instruction()) {
+        VLOG(4) << a << " is live out of computation and defined before " << b
+                << " which is in same computation";
+        return false;
+      }
+    }
+  }
+
   return true;
 }
 
@@ -274,23 +293,6 @@
          !LiveRangeStrictlyBefore(b, a, dataflow);
 }
 
-HloOrderingProto HloOrdering::ToProto() const {
-  HloOrderingProto proto;
-  for (const auto& computation : module_->computations()) {
-    const std::vector<const HloInstruction*>* sequence =
-        SequentialOrder(*computation);
-    if (sequence != nullptr) {
-      HloOrderingProto::SequentialComputation* proto_computation =
-          proto.add_sequential_computations();
-      proto_computation->set_computation_name(computation->name());
-      for (const HloInstruction* instruction : *sequence) {
-        *proto_computation->add_instruction_names() = instruction->name();
-      }
-    }
-  }
-  return proto;
-}
-
 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
     : HloOrdering(module) {}
 
@@ -336,15 +338,24 @@
   return ToStringHelper("DependencyHloOrdering");
 }
 
-SequentialHloOrdering::SequentialHloOrdering(
-    const HloModule* module, const HloModuleSequence& module_sequence)
-    : HloOrdering(module), module_sequence_(module_sequence) {
+SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
+    : HloOrdering(schedule.module()), schedule_(schedule) {
+  Initialize();
+}
+
+SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
+    : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
+  Initialize();
+}
+
+void SequentialHloOrdering::Initialize() {
   // Create a map from instruction to its order position.
-  for (auto computation_order : module_sequence_) {
-    const std::vector<const HloInstruction*>& order = computation_order.second;
+  TF_DCHECK_OK(schedule_.Verify());
+  for (const auto& computation_sequence : schedule_.sequences()) {
+    const std::vector<const HloInstruction*>& order =
+        computation_sequence.second.instructions();
     for (int i = 0; i < order.size(); ++i) {
-      DCHECK_EQ(0, order_position_.count(order[i]));
-      order_position_.emplace(order[i], i);
+      InsertOrDie(&order_position_, order[i], i);
     }
   }
 }
@@ -362,49 +373,13 @@
 const std::vector<const HloInstruction*>*
 SequentialHloOrdering::SequentialOrder(
     const HloComputation& computation) const {
-  auto find_it = module_sequence_.find(&computation);
-  return find_it == module_sequence_.end() ? nullptr : &find_it->second;
+  return schedule_.is_computation_scheduled(&computation)
+             ? &schedule_.sequence(&computation).instructions()
+             : nullptr;
 }
 
 string SequentialHloOrdering::ToString() const {
-  std::vector<string> pieces;
-  pieces.push_back("SequentialHloOrdering");
-  for (auto* computation : module_->computations()) {
-    pieces.push_back(
-        absl::StrFormat("computation %s order:", computation->name()));
-    // Gather all instructions in the module sequence for this computation and
-    // sort them by their position.
-    std::vector<const HloInstruction*> instructions;
-    for (auto& instruction_position : order_position_) {
-      const HloInstruction* instruction = instruction_position.first;
-      if (instruction->parent() == computation) {
-        instructions.push_back(instruction);
-      }
-    }
-    std::sort(instructions.begin(), instructions.end(),
-              [this](const HloInstruction* a, const HloInstruction* b) {
-                return order_position_.at(a) < order_position_.at(b);
-              });
-    for (auto instruction : instructions) {
-      pieces.push_back(absl::StrFormat("  %s", instruction->name()));
-    }
-  }
-  return absl::StrJoin(pieces, "\n");
-}
-
-std::ostream& operator<<(
-    std::ostream& out,
-    const SequentialHloOrdering::HloModuleSequence& module_sequence) {
-  for (auto computation_pair : module_sequence) {
-    const HloComputation* computation = computation_pair.first;
-    const std::vector<const HloInstruction*>& computation_sequence =
-        computation_pair.second;
-    out << "Computation " << computation->name() << ":\n";
-    for (auto* instruction : computation_sequence) {
-      out << "  " << instruction->name() << "\n";
-    }
-  }
-  return out;
+  return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index 985f3fa..b0361c3 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -25,6 +25,7 @@
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/hlo_value.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
@@ -71,10 +72,6 @@
 
   virtual string ToString() const = 0;
 
-  // Returns the serialized representation of this ordering.
-  // Only sequential computation orders are represented.
-  HloOrderingProto ToProto() const;
-
  protected:
   // Returns true if instruction 'a' executes before instruction 'b'.
   // Precondition: 'a' and 'b' are in the same computation.
@@ -183,17 +180,8 @@
 // interference is reduced relative to DependencyHloOrdering.
 class SequentialHloOrdering : public HloOrdering {
  public:
-  // TODO(dimvar): HloModuleSequence is not a good name because it sounds like
-  // a sequence of modules, instead of a map of schedules for all computations
-  // in a module. We should change it at some point.
-  //
-  // A sequence of instructions for each computation in the module.
-  using HloModuleSequence =
-      tensorflow::gtl::FlatMap<const HloComputation*,
-                               std::vector<const HloInstruction*>>;
-
-  SequentialHloOrdering(const HloModule* module,
-                        const HloModuleSequence& module_sequence);
+  SequentialHloOrdering(const HloSchedule& schedule);
+  SequentialHloOrdering(HloSchedule&& schedule);
   ~SequentialHloOrdering() override = default;
 
   // Returns the sequential instruction order for the given computation.
@@ -203,10 +191,12 @@
   string ToString() const override;
 
  protected:
+  void Initialize();
+
   bool ExecutesBeforeInSameComputation(const HloInstruction* a,
                                        const HloInstruction* b) const override;
 
-  const HloModuleSequence module_sequence_;
+  const HloSchedule schedule_;
 
   // The position of every instruction in the HLO module in its respective
   // computation sequence (a value of zero indicates the instruction is first in
@@ -217,10 +207,6 @@
   tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
 };
 
-std::ostream& operator<<(
-    std::ostream& out,
-    const SequentialHloOrdering::HloModuleSequence& module_sequence);
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 126d3a2..00970bc 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -23,11 +23,12 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace xla {
 namespace {
@@ -376,5 +377,104 @@
                                        dataflow->GetValueDefinedAt(add_3)));
 }
 
+TEST_F(HloOrderingTest,
+       ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) {
+  // Tests that values live out of the module should interfere with values
+  // defined after the root instruction. That is:
+  //
+  //   %param = param(0)
+  //   ROOT %root = negate(%param)
+  //   %dead = Constant(123.0)
+  //
+  // %root should interfere with %dead.
+  auto module = CreateNewModule();
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "param"));
+  HloInstruction* root = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+  HloInstruction* dead = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+  HloComputation* entry =
+      module->AddEntryComputation(builder.Build(/*root_instruction=*/root));
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(entry, {param, root, dead});
+  TF_ASSERT_OK(schedule.Verify());
+  SequentialHloOrdering ordering(schedule);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+  EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+      dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+      *dataflow));
+
+  EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+                                    dataflow->GetValueDefinedAt(dead),
+                                    *dataflow));
+}
+
+TEST_F(HloOrderingTest,
+       ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) {
+  // Tests that values live out of a computation should interfere with values
+  // defined after the root instruction of the computation. That is:
+  //
+  // subcomputation:
+  //   %param = param(0)
+  //   ROOT %root = negate(%param)
+  //   %dead = Constant(123.0)
+  //
+  // entry computation:
+  //   %c = constant(42.0)
+  //   ROOT %call = call({%c}), subcomputation
+  //
+  // %root should interfere with %dead.
+  auto module = CreateNewModule();
+  const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+  auto subbuilder = HloComputation::Builder(TestName() + ".sub");
+  HloInstruction* param = subbuilder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape, "param"));
+  HloInstruction* root = subbuilder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
+  HloInstruction* dead = subbuilder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+  HloComputation* subcomputation = module->AddEmbeddedComputation(
+      subbuilder.Build(/*root_instruction=*/root));
+
+  auto builder = HloComputation::Builder(TestName());
+  HloInstruction* c = builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+  HloInstruction* call = builder.AddInstruction(
+      HloInstruction::CreateCall(scalar_shape, {c}, subcomputation));
+  HloComputation* entry = module->AddEntryComputation(builder.Build());
+
+  HloSchedule schedule(module.get());
+  schedule.set_sequence(subcomputation, {param, root, dead});
+  schedule.set_sequence(entry, {c, call});
+  TF_ASSERT_OK(schedule.Verify());
+  SequentialHloOrdering ordering(schedule);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+  EXPECT_TRUE(ordering.ExecutesBefore(root, dead));
+  EXPECT_FALSE(ordering.ExecutesBefore(dead, root));
+
+  EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+      dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead),
+      *dataflow));
+
+  EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root),
+                                    dataflow->GetValueDefinedAt(dead),
+                                    *dataflow));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ea8e6a2..11caa89 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -44,6 +45,20 @@
 
 const double kF16max = 65504;
 
+// Creates and returns a schedule created using the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+  HloSchedule schedule(module);
+  for (const HloComputation* computation : module->computations()) {
+    if (!computation->IsFusionComputation()) {
+      for (const HloInstruction* instruction : computation->instructions()) {
+        schedule.GetOrCreateSequence(computation).push_back(instruction);
+      }
+    }
+  }
+  return schedule;
+}
+
 // Parser for the HloModule::ToString() format text.
 class HloParser {
  public:
@@ -90,16 +105,13 @@
                             string* root_name);
   bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
   bool ParseControlPredecessors(HloInstruction* instruction);
-  bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
-  bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
-  bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
-                            const Shape& shape);
-  bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
-  bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
-                          const Shape& shape);
+  bool ParseLiteral(Literal* literal, const Shape& shape);
+  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
+  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
+  bool ParseDenseLiteral(Literal* literal, const Shape& shape);
+  bool ParseSparseLiteral(Literal* literal, const Shape& shape);
   template <typename LiteralNativeT>
-  bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
-                                const Shape& shape);
+  bool ParseSparseLiteralHelper(Literal* literal, const Shape& shape);
 
   // Sets the sub-value of literal at the given index to the given value. The
   // literal's shape must have the default layout.
@@ -221,7 +233,7 @@
   bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
 
   bool ParseSliceRanges(SliceRanges* result);
-  bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result);
+  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
   bool ParseInt64List(const TokKind start, const TokKind end,
                       const TokKind delim,
                       std::vector<tensorflow::int64>* result);
@@ -240,7 +252,7 @@
   bool ParseFftType(FftType* result);
   bool ParseFusionKind(HloInstruction::FusionKind* result);
   bool ParseRandomDistribution(RandomDistribution* result);
-  bool ParsePrecision(PrecisionConfigProto::Precision* result);
+  bool ParsePrecision(PrecisionConfig::Precision* result);
   bool ParseInt64(tensorflow::int64* result);
   bool ParseDouble(double* result);
   bool ParseBool(bool* result);
@@ -366,9 +378,25 @@
     return false;
   }
 
+  absl::optional<bool> is_scheduled;
+  std::unordered_map<string, AttrConfig> attrs;
+  attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+  if (!ParseAttributes(attrs)) {
+    return false;
+  }
+
   module_ = absl::make_unique<HloModule>(name, config_);
 
-  return ParseComputations();
+  if (!ParseComputations()) {
+    return false;
+  }
+
+  if (is_scheduled.has_value() && *is_scheduled) {
+    TF_CHECK_OK(
+        module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+  }
+
+  return true;
 }
 
 // computations ::= (computation)+
@@ -530,10 +558,6 @@
   attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                              &backend_config};
 
-  optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
-  attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
-                                &operand_precision};
-
   HloInstruction* instruction;
   switch (opcode) {
     case HloOpcode::kParameter: {
@@ -550,7 +574,7 @@
       break;
     }
     case HloOpcode::kConstant: {
-      std::unique_ptr<Literal> literal;
+      Literal literal;
       if (!ParseToken(TokKind::kLparen,
                       "expects '(' before constant literal") ||
           !ParseLiteral(&literal, shape) ||
@@ -913,6 +937,9 @@
                              AttrTy::kConvolutionDimensionNumbers, &dnums};
       attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
                                       &feature_group_count};
+      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
+      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+                                    &operand_precision};
       if (!ParseOperands(&operands, /*expected_size=*/2) ||
           !ParseAttributes(attrs)) {
         return false;
@@ -923,9 +950,17 @@
       if (!feature_group_count) {
         feature_group_count = 1;
       }
+      PrecisionConfig precision_config;
+      if (operand_precision) {
+        *precision_config.mutable_operand_precision() = {
+            operand_precision->begin(), operand_precision->end()};
+      } else {
+        precision_config.mutable_operand_precision()->Resize(
+            operands.size(), PrecisionConfig::DEFAULT);
+      }
       instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
-          shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums,
-          feature_group_count.value()));
+          shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
+          feature_group_count.value(), *window, *dnums, precision_config));
       break;
     }
     case HloOpcode::kFft: {
@@ -1241,11 +1276,14 @@
       optional<string> custom_call_target;
       optional<Window> window;
       optional<ConvolutionDimensionNumbers> dnums;
+      optional<int64> feature_group_count;
       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
                                      &custom_call_target};
       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
       attrs["dim_labels"] = {/*required=*/false,
                              AttrTy::kConvolutionDimensionNumbers, &dnums};
+      attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+                                      &feature_group_count};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
@@ -1257,6 +1295,9 @@
       if (dnums.has_value()) {
         instruction->set_convolution_dimension_numbers(*dnums);
       }
+      if (feature_group_count.has_value()) {
+        instruction->set_feature_group_count(*feature_group_count);
+      }
       break;
     }
     case HloOpcode::kDot: {
@@ -1272,6 +1313,9 @@
       optional<std::vector<tensorflow::int64>> rhs_batch_dims;
       attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
                                  &rhs_batch_dims};
+      optional<std::vector<PrecisionConfig::Precision>> operand_precision;
+      attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+                                    &operand_precision};
 
       if (!ParseOperands(&operands, /*expected_size=*/2) ||
           !ParseAttributes(attrs)) {
@@ -1296,8 +1340,17 @@
                                                 rhs_batch_dims->end()};
       }
 
-      instruction = builder->AddInstruction(
-          HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
+      PrecisionConfig precision_config;
+      if (operand_precision) {
+        *precision_config.mutable_operand_precision() = {
+            operand_precision->begin(), operand_precision->end()};
+      } else {
+        precision_config.mutable_operand_precision()->Resize(
+            operands.size(), PrecisionConfig::DEFAULT);
+      }
+
+      instruction = builder->AddInstruction(HloInstruction::CreateDot(
+          shape, operands[0], operands[1], dnum, precision_config));
       break;
     }
     case HloOpcode::kGather: {
@@ -1414,12 +1467,6 @@
   if (backend_config) {
     instruction->set_raw_backend_config_string(std::move(*backend_config));
   }
-  if (operand_precision) {
-    PrecisionConfigProto precision_config;
-    *precision_config.mutable_operand_precision() = {operand_precision->begin(),
-                                                     operand_precision->end()};
-    instruction->set_precision_config(precision_config);
-  }
   return AddInstruction(name, instruction, name_loc);
 }  // NOLINT(readability/fn_size)
 
@@ -1760,8 +1807,7 @@
 // literal
 //  ::= tuple
 //  ::= non_tuple
-bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
-                             const Shape& shape) {
+bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
   return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
                                    : ParseNonTupleLiteral(literal, shape);
 }
@@ -1771,8 +1817,7 @@
 // literal_list
 //  ::= /*empty*/
 //  ::= literal (',' literal)*
-bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
-                                  const Shape& shape) {
+bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
   if (!EatShapeAndCheckCompatible(shape)) {
     return TokenError(StrCat("expects tuple constant in shape ",
                              ShapeUtil::HumanString(shape)));
@@ -1780,8 +1825,7 @@
   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
     return false;
   }
-  std::vector<std::unique_ptr<Literal>> elements(
-      ShapeUtil::TupleElementCount(shape));
+  std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
 
   if (lexer_.GetKind() == TokKind::kRparen) {
     // empty
@@ -1807,8 +1851,7 @@
 //   ::= rank01
 //   ::= rank2345
 // rank2345 ::= shape sparse_or_nested_array
-bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
-                                     const Shape& shape) {
+bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
   if (LayoutUtil::IsSparseArray(shape)) {
     return ParseSparseLiteral(literal, shape);
   }
@@ -1817,8 +1860,7 @@
   return ParseDenseLiteral(literal, shape);
 }
 
-bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
-                                  const Shape& shape) {
+bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
   const tensorflow::int64 rank = ShapeUtil::Rank(shape);
   if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
     return false;
@@ -1912,7 +1954,7 @@
           // TODO(congliu): bool type literals with rank >= 1 are actually
           // printed in a compact form instead of "true" or "false". Fix that.
           if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
-                                 linear_index++, literal->get())) {
+                                 linear_index++, literal)) {
             return false;
           }
           lexer_.Lex();
@@ -1923,7 +1965,7 @@
             return Error(loc, StrCat("expects integer for primitive type: ",
                                      PrimitiveType_Name(shape.element_type())));
           }
-          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+          if (!SetValueInLiteral(value, linear_index++, literal)) {
             return false;
           }
         } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
@@ -1934,7 +1976,7 @@
                 loc, StrCat("expect floating point value for primitive type: ",
                             PrimitiveType_Name(shape.element_type())));
           }
-          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
+          if (!SetValueInLiteral(value, linear_index++, literal)) {
             return false;
           }
         } else {
@@ -1946,12 +1988,11 @@
     }  // end of switch
   } while (nest_level > 0);
 
-  *literal = (*literal)->Relayout(shape.layout());
+  *literal = literal->Relayout(shape.layout());
   return true;
 }
 
-bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
-                                   const Shape& shape) {
+bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
   if (!EatShapeAndCheckCompatible(shape)) {
     return false;
   }
@@ -1991,13 +2032,12 @@
 }
 
 template <typename LiteralNativeT>
-bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
-                                         const Shape& shape) {
+bool HloParser::ParseSparseLiteralHelper(Literal* literal, const Shape& shape) {
   std::vector<tensorflow::int64> index;
 
   tensorflow::int64 rank = ShapeUtil::Rank(shape);
 
-  *literal = absl::make_unique<Literal>(shape);
+  *literal = Literal(shape);
 
   if (!ParseToken(TokKind::kLbrace,
                   "expects '{' at the beginning of a sparse literal")) {
@@ -2071,7 +2111,7 @@
       return false;
     }
 
-    if ((*literal)->sparse_element_count() + 1 ==
+    if (literal->sparse_element_count() + 1 ==
         LayoutUtil::MaxSparseElements(shape.layout())) {
       return Error(
           lexer_.GetLoc(),
@@ -2079,10 +2119,10 @@
                  ShapeUtil::HumanStringWithLayout(shape)));
     }
 
-    (*literal)->AppendSparseElement(index, value);
+    literal->AppendSparseElement(index, value);
   }
 
-  (*literal)->SortSparseElements();
+  literal->SortSparseElements();
   return true;
 }
 
@@ -2397,11 +2437,11 @@
         return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
       }
       case AttrTy::kPrecisionList: {
-        std::vector<PrecisionConfigProto::Precision> result;
+        std::vector<PrecisionConfig::Precision> result;
         if (!ParsePrecisionList(&result)) {
           return false;
         }
-        static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>(
+        static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
             attr_out_ptr)
             ->emplace(result);
         return true;
@@ -2685,9 +2725,9 @@
 //   ::= /*empty*/
 //   ::= precision_val (delim precision_val)*
 bool HloParser::ParsePrecisionList(
-    std::vector<PrecisionConfigProto::Precision>* result) {
+    std::vector<PrecisionConfig::Precision>* result) {
   auto parse_and_add_item = [&]() {
-    PrecisionConfigProto::Precision item;
+    PrecisionConfig::Precision item;
     if (!ParsePrecision(&item)) {
       return false;
     }
@@ -3019,7 +3059,7 @@
   return true;
 }
 
-bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) {
+bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
   VLOG(1) << "ParsePrecision";
   if (lexer_.GetKind() != TokKind::kIdent) {
     return TokenError("expects random distribution");
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 7597894..cca50fa 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -19,6 +19,8 @@
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -382,7 +384,7 @@
   %input = f32[1,2,1]{2,1,0} parameter(0)
   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
   %filter = f32[1,1,1]{2,1,0} parameter(1)
-  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default}
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
 }
 
 )"
@@ -395,7 +397,7 @@
 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
   %input = f32[1,2]{1,0} parameter(0)
   %filter = f32[1,1]{1,0} parameter(1)
-  ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1
+  ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
 }
 
 )"
@@ -408,7 +410,7 @@
 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
   %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
   %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
-  ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1
+  ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
 }
 
 )"
@@ -1121,18 +1123,31 @@
 
 )"
 },
-// custom-call with window and dim_labels
+// custom-call with window, dim_labels and feature_group_count
 {
-"CustomCallWithWindowAndDimLabels",
-R"(HloModule CustomCallWithWindowAndDimLabels
+"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
+R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
 
 ENTRY Computation {
-  ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+  ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
+}
+
+)"
+    },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+  keys = f32[1024]{0} parameter(0)
+  values = s32[1024]{0} parameter(1)
+  ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
 }
 
 )"
 }
-  });
+});
   // clang-format on
 }
 
@@ -1775,5 +1790,107 @@
       ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
 }
 
+TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
+  const string text =
+      R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+  const HloComputation* computation = module->entry_computation();
+  ASSERT_NE(computation, nullptr);
+  EXPECT_THAT(computation->root_instruction(),
+              op::Convolution(op::Parameter(0), op::Parameter(1)));
+  auto* convolution =
+      Cast<HloConvolutionInstruction>(computation->root_instruction());
+  EXPECT_EQ(convolution->feature_group_count(), 1);
+}
+
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+  const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %x = f32[2,4]{1,0} parameter(1)
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  %y = f32[2,4]{1,0} parameter(2)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+  const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %x = f32[2,4]{1,0} parameter(1)
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  %y = f32[2,4]{1,0} parameter(2)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+  const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %x = f32[2,4]{1,0} parameter(1)
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  %y = f32[2,4]{1,0} parameter(2)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_TRUE(module->has_schedule());
+  TF_ASSERT_OK(module->schedule().Verify());
+  EXPECT_EQ(module->schedule().sequences().size(), 1);
+  ASSERT_TRUE(
+      module->schedule().is_computation_scheduled(module->entry_computation()));
+  EXPECT_THAT(
+      module->schedule().sequence(module->entry_computation()).instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+                             op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+  // As above but with a different schedule order.
+  const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %x = f32[2,4]{1,0} parameter(1)
+  %y = f32[2,4]{1,0} parameter(2)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(text));
+  ASSERT_TRUE(module->has_schedule());
+  TF_ASSERT_OK(module->schedule().Verify());
+  EXPECT_EQ(module->schedule().sequences().size(), 1);
+  ASSERT_TRUE(
+      module->schedule().is_computation_scheduled(module->entry_computation()));
+  EXPECT_THAT(
+      module->schedule().sequence(module->entry_computation()).instructions(),
+      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+                             op::Broadcast(), op::Multiply(), op::Add()));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679..b9c0b0c 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@
 
 HloProto MakeHloProto(const HloModule& module,
                       const BufferAssignment& assignment) {
-  HloOrderingProto proto_ordering =
-      assignment.liveness().hlo_ordering().ToProto();
   BufferAssignmentProto proto_assignment = assignment.ToProto();
   HloProto proto = MakeHloProto(module);
-  proto.mutable_hlo_ordering()->Swap(&proto_ordering);
   proto.mutable_buffer_assignment()->Swap(&proto_assignment);
   return proto;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 585c959..d9848ce 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -20,13 +20,13 @@
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 
 namespace xla {
 
 namespace {
 
-class HloReachabilityTest : public HloTestBase {};
+class HloReachabilityTest : public HloVerifiedTestBase {};
 
 TEST_F(HloReachabilityTest, Reachability) {
   // Construct and test a reachability graph of the following form:
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c962992..bd6dd79 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -27,15 +27,14 @@
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_dce.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
@@ -962,8 +961,7 @@
 }
 
 StatusOr<bool> HloRematerialization::RematerializeComputation(
-    HloComputation* computation,
-    SequentialHloOrdering::HloModuleSequence* sequence,
+    HloComputation* computation, HloSchedule* schedule,
     int64 memory_limit_bytes) {
   VLOG(1) << "Rematerializing computation " << computation->name()
           << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
@@ -971,7 +969,8 @@
           << HumanReadableNumBytes(computation_peak_memory_.at(computation));
   CHECK(!ContainsKey(rematerialized_computations_, computation));
 
-  InstructionList instruction_list(sequence->at(computation));
+  InstructionList instruction_list(
+      schedule->sequence(computation).instructions());
   MemoryUsageTracker memory_tracker(computation, size_function_,
                                     *points_to_analysis_, instruction_list);
   bool changed = false;
@@ -1145,7 +1144,7 @@
               0, memory_limit_bytes - memory_tracker.memory_usage());
           TF_ASSIGN_OR_RETURN(
               bool subcomputation_changed,
-              RematerializeComputation(called_computation, sequence,
+              RematerializeComputation(called_computation, schedule,
                                        subcomputation_memory_limit_bytes));
           changed |= subcomputation_changed;
         }
@@ -1179,12 +1178,12 @@
   computation_peak_memory_.at(computation) = peak_memory;
 
   // Update order to include rematerialized instructions.
-  auto& dst = sequence->at(computation);
-  dst.clear();
+  HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
+  sequence.clear();
   for (auto* item = instruction_list.first(); item != nullptr;
        item = instruction_list.next(item)) {
     const HloInstruction* instruction = item->instruction;
-    dst.push_back(instruction);
+    sequence.push_back(instruction);
   }
   rematerialized_computations_.insert(computation);
 
@@ -1194,59 +1193,12 @@
   return changed;
 }
 
-StatusOr<bool> HloRematerialization::Run(
-    HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
-    int64 memory_limit_bytes, RematerializationSizes* sizes,
-    CopyInsertion* copy_insertion) {
-  // The sequence is constructed entirely by this method.
-  TF_RET_CHECK(sequence->empty());
-
+StatusOr<bool> HloRematerialization::Run(HloModule* module) {
   VLOG(1) << "HloRematerialization() with memory limit of "
-          << HumanReadableNumBytes(memory_limit_bytes);
+          << HumanReadableNumBytes(memory_limit_bytes_);
   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
 
-  // Create initial sequence of HLO instructions.
-  TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
-                                     *module,
-                                     [this](const BufferValue& buffer) {
-                                       return size_function_(buffer.shape());
-                                     },
-                                     scheduler_algorithm_));
-  if (copy_insertion) {
-    // We run a separate pass of copy elision here because the sequential
-    // ordering from the HLO schedule allows for more copies to be eliminated.
-    // TODO(b/80249101): Instead of a separate copy elision pass, use the
-    // ordering from the HLO schedule directly for copy insertion.
-
-    // First create a copy of the schedule which contains HloInstruction unique
-    // ids instead of HloInstruction*. This is necessary for updating the
-    // schedule below.
-    // TODO(b/113175018): Remove this when the HLO schedule is self-contained
-    // and can update itself.
-    tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-        id_sequence = ComputeIdSchedule(*sequence);
-
-    SequentialHloOrdering ordering(module, *sequence);
-    TF_RETURN_IF_ERROR(
-        copy_insertion->RemoveUnnecessaryCopies(ordering, module));
-
-    // RemoveUnnecessaryCopies only considers interference when determining
-    // whether it is legal to remove a copy. However, copies in the graph may be
-    // necessary for other reason such as preventing a constant from being live
-    // out of the graph. So run AddSpecialCaseCopies to re-insert these copies.
-    // TODO(b/80249101): Break copy insertion into several passes and run each
-    // one once in the regular HLO pipeline.
-    TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module));
-
-    // The passes above can add and remove copies, update the schedule to
-    // account for these transformations. Newly added instructions will be
-    // placed ASAP in the schedule.
-    TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence));
-
-    TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference(
-        SequentialHloOrdering(module, *sequence), module));
-  }
-
+  TF_RET_CHECK(module->has_schedule());
   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
 
   // Adjust memory limit to account for the output of the entry
@@ -1262,7 +1214,7 @@
       });
 
   const int64 adjusted_memory_limit_bytes =
-      memory_limit_bytes - module_output_size;
+      memory_limit_bytes_ - module_output_size;
   VLOG(1) << "Adjusted memory limit accounting for output ("
           << HumanReadableNumBytes(module_output_size)
           << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
@@ -1271,12 +1223,14 @@
   // sequential context.
   call_graph_ = CallGraph::Build(module);
   TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
-      [this, sequence](const CallGraphNode& node) -> Status {
+      [this, module](const CallGraphNode& node) -> Status {
         if (node.context() == CallContext::kSequential) {
           TF_ASSIGN_OR_RETURN(
               computation_peak_memory_[node.computation()],
               ComputePeakMemory(node.computation(),
-                                sequence->at(node.computation())));
+                                module->schedule()
+                                    .sequence(node.computation())
+                                    .instructions()));
         }
         return Status::OK();
       },
@@ -1294,9 +1248,10 @@
 
   // Subcomputations called by the entry computation will also be
   // rematerialized.
-  TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation(
-                                        module->entry_computation(), sequence,
-                                        adjusted_memory_limit_bytes));
+  TF_ASSIGN_OR_RETURN(
+      bool changed,
+      RematerializeComputation(module->entry_computation(), &module->schedule(),
+                               adjusted_memory_limit_bytes));
 
   // Rematerialization can introduce dead code. This occurs if all uses of an
   // instruction are replaced with rematerializations of the instruction.
@@ -1305,30 +1260,7 @@
 
   // After DCE, the module sequence may include instructions which no longer
   // exist.
-  for (const auto* computation : module->MakeNonfusionComputations()) {
-    if (sequence->at(computation).size() != computation->instruction_count()) {
-      // A size mismatch between the computation instruction count and the size
-      // of the ordering of instructions can only be caused by DCE. Rebuild the
-      // order by removing the deleted instructions from the order.
-      tensorflow::gtl::FlatSet<const HloInstruction*> instruction_set;
-      for (const auto& instruction : computation->instructions()) {
-        instruction_set.insert(instruction);
-      }
-      // Move the old order into a temporary vector, then build new order
-      // inplace.
-      std::vector<const HloInstruction*>& order = sequence->at(computation);
-      std::vector<const HloInstruction*> old_order;
-      using std::swap;
-      swap(order, old_order);
-      std::copy_if(old_order.begin(), old_order.end(),
-                   std::back_inserter(order),
-                   [&instruction_set](const HloInstruction* instruction) {
-                     return ContainsKey(instruction_set, instruction);
-                   });
-      TF_RET_CHECK(sequence->at(computation).size() ==
-                   computation->instruction_count());
-    }
-  }
+  TF_RETURN_IF_ERROR(module->schedule().Update());
   VLOG(1) << "Rematerialized " << instructions_rematerialized_
           << " instructions in module " << module->name() << "; "
           << net_instructions_added_ << " net instructions added";
@@ -1345,33 +1277,22 @@
           << HumanReadableNumBytes(reduced_peak_memory) << " ("
           << reduced_peak_memory << " bytes)";
 
-  if (sizes != nullptr) {
-    sizes->before_bytes = before_peak_memory;
-    sizes->after_bytes = current_peak_memory;
+  if (sizes_ != nullptr) {
+    sizes_->before_bytes = before_peak_memory;
+    sizes_->after_bytes = current_peak_memory;
   }
 
   XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
 
-  if (current_peak_memory > memory_limit_bytes) {
+  if (current_peak_memory > memory_limit_bytes_) {
     LOG(WARNING) << absl::StrFormat(
         "Can't reduce memory use below %s (%d bytes) by rematerialization; "
         "only reduced to %s (%d bytes)",
-        HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes,
+        HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
         HumanReadableNumBytes(current_peak_memory), current_peak_memory);
   }
 
   return changed;
 }
 
-/* static */ StatusOr<bool> HloRematerialization::RematerializeAndSchedule(
-    const HloRematerialization::ShapeSizeFunction& size_function,
-    int64 memory_limit_bytes, HloModule* hlo_module,
-    MemorySchedulerAlgorithm scheduler_algorithm,
-    SequentialHloOrdering::HloModuleSequence* sequence,
-    RematerializationSizes* sizes, CopyInsertion* copy_insertion) {
-  HloRematerialization remat(scheduler_algorithm, size_function);
-  return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
-                   copy_insertion);
-}
-
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ec0043..e2aaf18 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -17,16 +17,23 @@
 
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/copy_insertion.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 
 namespace xla {
 
-class HloRematerialization {
+// HLO pass which rematerializes instructions to reduce peak memory use, where
+// memory use is defined as the total size of all live HLO instruction
+// values. Parameters and constants are included in memory use estimates.
+//
+// CSE will undo the effects of this optimization and should not be run after
+// this pass. In general, this pass should be run very late, immediately before
+// code generation.
+class HloRematerialization : public HloPassInterface {
  public:
   using ShapeSizeFunction = std::function<int64(const Shape&)>;
 
@@ -37,10 +44,7 @@
     int64 after_bytes;
   };
 
-  // Rematerialize HLO instructions in the given module to reduce peak memory
-  // use below memory_limit_bytes where memory use is defined as the total size
-  // of all live HLO instruction values. Parameters and constants are included
-  // in memory use estimates. Method parameters:
+  // Constructor parameters:
   //
   //   size_function: Function which returns the size in bytes of the top-level
   //     buffer of the given shape.
@@ -48,60 +52,34 @@
   //   memory_limit_bytes: The threshold number of bytes to reduce memory use to
   //     via rematerialization.
   //
-  //   hlo_module: HLO module to rematerialize instructions in.
-  //
-  //   sequence: Should point to an empty HloModuleSequence. Upon return
-  //     contains the HLO instruction order which was used for
-  //     rematerialization. This is the order in which HLO instructions should
-  //     be emitted to minimize memory use.
-  //
-  //   sizes: Optional outparam that indicates the peak memory usage of the HLO
-  //     module before/after rematerialization.
-  //
-  //   copy_insertion: If non-null, run copy elision after scheduling. This
-  //     pass is used to eliminate copies that were inserted by copy insertion
-  //     before HLO scheduling.
-  //
-  // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
-  // insertion is integrated with HLO scheduling.
-  //
-  // Returns whether any instructions were rematerialized. If memory use is
-  // already below the given limit then no instructions are rematerialized and
-  // false is returned.
-  //
-  // CSE will undo the effects of this optimization and should not be run after
-  // this pass. In general, this pass should be run very late immediately before
-  // code generation.
-  static StatusOr<bool> RematerializeAndSchedule(
-      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
-      HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
-      SequentialHloOrdering::HloModuleSequence* sequence,
-      RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr);
-
- protected:
-  HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
-                       const ShapeSizeFunction& size_function)
-      : scheduler_algorithm_(scheduler_algorithm),
-        size_function_(size_function) {}
+  //   sizes: Pointer to data structure which records the peak memory usage of
+  //     the HLO module before/after rematerialization. Values are set during
+  //     Run(). Can be nullptr.
+  HloRematerialization(const ShapeSizeFunction& size_function,
+                       int64 memory_limit_bytes, RematerializationSizes* sizes)
+      : size_function_(size_function),
+        memory_limit_bytes_(memory_limit_bytes),
+        sizes_(sizes) {}
   ~HloRematerialization() {}
 
-  // Runs rematerialization on the given module. Returns whether the module was
-  // changed. memory_limit is the target maximum peak memory usage by the
-  // module. sequence should be an empty HloModuleSequence. Upon return sequence
-  // contains the memory-minimizing order in which to emit the HLO instructions.
-  StatusOr<bool> Run(HloModule* module,
-                     SequentialHloOrdering::HloModuleSequence* sequence,
-                     int64 memory_limit, RematerializationSizes* sizes,
-                     CopyInsertion* copy_insertion);
+  absl::string_view name() const override { return "rematerialization"; }
 
+  // Runs rematerialization on the given module. Returns whether the module was
+  // changed. Requires that the module has a schedule set
+  // (HloModule::has_schedule() is true) before running. Returns whether any
+  // instructions were rematerialized. If memory use is already below the limit
+  // specified in the constructor then no instructions are rematerialized and
+  // false is returned.
+  StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
   // Rematerializes instructions within the given computation. 'order' is the
   // order in which the computation's instructions will be emitted in the
   // backend. Rematerialized instructions will be added to the HLO computation
   // and inserted into 'order'.
-  StatusOr<bool> RematerializeComputation(
-      HloComputation* computation,
-      SequentialHloOrdering::HloModuleSequence* sequence,
-      int64 computation_memory_limit);
+  StatusOr<bool> RematerializeComputation(HloComputation* computation,
+                                          HloSchedule* schedule,
+                                          int64 memory_limit_bytes);
 
   // Computes and returns the peak memory used by the given computation. The
   // peak memory is the maximum total size of all live HLO instruction values at
@@ -122,6 +100,14 @@
   // Function which computes the size of the top-level buffer of a shape.
   const ShapeSizeFunction size_function_;
 
+  // The threshold number of bytes to reduce memory use to via
+  // rematerialization.
+  const int64 memory_limit_bytes_;
+
+  // Pointer to data structure which records the peak memory usage of the HLO
+  // module before/after rematerialization
+  RematerializationSizes* sizes_;
+
   // Call graph of the hlo_module.
   std::unique_ptr<CallGraph> call_graph_;
 
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index ac8c97d..f7e82fb 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -24,7 +24,7 @@
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -36,7 +36,7 @@
 
 using ::testing::_;
 
-class HloRematerializationTest : public HloTestBase {
+class HloRematerializationTest : public HloVerifiedTestBase {
  protected:
   // Creates and returns a computation which can benefit from
   // rematerialization. The computation looks like:
@@ -141,13 +141,16 @@
     return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
   }
 
-  StatusOr<bool> RunHloRematerialization(
-      int64 memory_limit_bytes, HloModule* module,
-      SequentialHloOrdering::HloModuleSequence* sequence) {
+  StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
+                                         HloModule* module) {
     TF_EXPECT_OK(verifier().Run(module).status());
-    return HloRematerialization::RematerializeAndSchedule(
-        ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
-        sequence, /*sizes=*/nullptr);
+    HloMemoryScheduler scheduler(
+        [](const BufferValue& buffer) { return ByteSizeOf(buffer.shape()); },
+        DefaultMemoryScheduler);
+    TF_EXPECT_OK(scheduler.Run(module).status());
+    HloRematerialization remat(ByteSizeOf, memory_limit_bytes,
+                               /*sizes=*/nullptr);
+    return remat.Run(module);
   }
 
   // Various shapes used in the canned computations.
@@ -170,12 +173,11 @@
   const HloInstruction* concat = slice->operand(0);
   const HloInstruction* bcast = concat->operand(0);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
   // Computation requires 16KB without rematerialization, but uses only 12KB
   // with rematerialization so pick a memory limit between these values (14KB).
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/14 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/14 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Root should not have changed.
@@ -187,9 +189,13 @@
 
   // The rematerialized broadcast should be immediate before the concat in the
   // sequence.
-  EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2],
+  EXPECT_EQ(module->schedule()
+                .sequence(computation)
+                .instructions()[computation->instruction_count() - 2],
             concat);
-  EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3],
+  EXPECT_EQ(module->schedule()
+                .sequence(computation)
+                .instructions()[computation->instruction_count() - 3],
             remat_bcast);
 }
 
@@ -203,10 +209,9 @@
 
   EXPECT_EQ(computation->instruction_count(), 8);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/20 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/20 * 1024, module));
 
   // No instructions should have been materialized.
   EXPECT_FALSE(changed);
@@ -242,10 +247,9 @@
   // The body computation uses 16KB and the entry computation uses 2KB at the
   // while so the peak memory use of the module is 18KB. Set the memory limit a
   // bit lower (17KB) to force rematerialization of the entry computation.
-  SequentialHloOrdering::HloModuleSequence sequence;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/17 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/17 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Only the entry computation should have a rematerialized instruction added.
@@ -276,10 +280,9 @@
   EXPECT_EQ(entry_computation->instruction_count(), 7);
   EXPECT_EQ(body_computation->instruction_count(), 8);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/15 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/15 * 1024, module));
   EXPECT_TRUE(changed);
 
   // Both computations should have rematerialized instructions added.
@@ -316,10 +319,9 @@
 
   // If all computations are maximally rematerialized then peak memory usage is
   // ~12K so pick something slightly larger.
-  SequentialHloOrdering::HloModuleSequence sequence;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/13 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/13 * 1024, module));
   EXPECT_TRUE(changed);
 
   // All computations should have rematerialized instructions added.
@@ -382,14 +384,13 @@
   ASSERT_EQ(count_rngs(entry_computation), 1);
   const int64 original_instruction_count =
       entry_computation->instruction_count();
-  SequentialHloOrdering::HloModuleSequence sequence;
   // Pick a memory limit some where between 24KB (initial peak memory including
   // parameter and output) and 20KB (peak memory possible with
   // rematerialization).
   TF_ASSERT_OK_AND_ASSIGN(
-      bool changed, RunHloRematerialization(
-                        /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
-                        module.get(), &sequence));
+      bool changed,
+      RunHloRematerialization(
+          /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), module));
   EXPECT_TRUE(changed);
   // The rng should not have been rematerialized.
   EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,13 +477,12 @@
   EXPECT_EQ(add_3->operand(0), bcast);
   EXPECT_EQ(add_4->operand(0), bcast);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
   // Pick a memory limit some where between 24KB (initial peak memory including
   // parameter and output) and 20KB (peak memory possible with
   // rematerialization).
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/22 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/22 * 1024, module));
   EXPECT_TRUE(changed);
 
   // The broadcast should have been rematerialized 3 times.
@@ -571,13 +571,12 @@
 
   EXPECT_EQ(entry_computation->instruction_count(), 8);
 
-  SequentialHloOrdering::HloModuleSequence sequence;
   // Pick a memory limit some where between 24KB (initial peak memory including
   // parameter and output) and 20KB (peak memory possible with
   // rematerialization).
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
-                                            /*memory_limit_bytes=*/22 * 1024,
-                                            module.get(), &sequence));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed,
+                          RunHloRematerialization(
+                              /*memory_limit_bytes=*/22 * 1024, module));
   // Rematerialization should only occur if the rematerializable instruction has
   // no indirect uses.
   if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 66ac1f6..fa7f216 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -118,16 +118,16 @@
 }
 
 StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
-    const absl::Span<const std::unique_ptr<Literal>> literals) {
+    const absl::Span<const Literal> literals) {
   std::vector<const Literal*> literal_pointers;
   literal_pointers.reserve(literals.size());
   for (const auto& literal : literals) {
-    literal_pointers.push_back(literal.get());
+    literal_pointers.push_back(&literal);
   }
   return TransferLiteralsToDevice(literal_pointers);
 }
 
-StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
+StatusOr<Literal> HloRunner::TransferLiteralFromDevice(
     const ShapedBuffer& buffer) {
   TF_ASSIGN_OR_RETURN(
       auto stream, backend().BorrowStream(backend().default_stream_executor()));
@@ -135,7 +135,7 @@
                                                                  buffer);
 }
 
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
+StatusOr<Literal> HloRunner::Execute(
     std::unique_ptr<HloModule> module,
     const absl::Span<const Literal* const> arguments, bool run_hlo_passes,
     ExecutionProfile* profile) {
@@ -150,15 +150,15 @@
   return TransferLiteralFromDevice(result);
 }
 
-StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
-    std::unique_ptr<HloModule> module,
-    const absl::Span<const std::unique_ptr<Literal>> arguments,
-    bool run_hlo_passes, ExecutionProfile* profile) {
+StatusOr<Literal> HloRunner::Execute(std::unique_ptr<HloModule> module,
+                                     const absl::Span<const Literal> arguments,
+                                     bool run_hlo_passes,
+                                     ExecutionProfile* profile) {
   // Construct a vector of plain pointers for the arguments.
   std::vector<const Literal*> argument_pointers;
   argument_pointers.reserve(arguments.size());
   for (const auto& argument : arguments) {
-    argument_pointers.push_back(argument.get());
+    argument_pointers.push_back(&argument);
   }
   return Execute(
       /*module=*/std::move(module),
@@ -204,7 +204,7 @@
       /*profile=*/profile);
 }
 
-StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
     std::unique_ptr<HloModule> module,
     const ReplicatedExecuteOptions& options) {
   TF_ASSIGN_OR_RETURN(
@@ -290,9 +290,9 @@
         VLOG(1) << "Starting outfeed on device " << device;
         for (int64 step = 1;
              options.infeed_steps < 0 || step <= options.infeed_steps; ++step) {
-          auto literal = absl::make_unique<Literal>();
+          Literal literal;
           TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
-              executor, options.outfeed_shape, literal.get()));
+              executor, options.outfeed_shape, &literal));
           if (options.outfeed_values != nullptr) {
             options.outfeed_values->push_back(std::move(literal));
           }
@@ -310,10 +310,10 @@
                                                    argument_buffer_slices));
   LOG(INFO) << "Replicated execution terminated";
 
-  std::vector<std::unique_ptr<Literal>> exec_results;
+  std::vector<Literal> exec_results;
   for (int64 i = 0; i < options.num_replicas; ++i) {
     TF_RETURN_IF_ERROR(streams[i]->BlockHostUntilDone());
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+    TF_ASSIGN_OR_RETURN(Literal literal,
                         backend().transfer_manager()->TransferLiteralFromDevice(
                             streams[i].get(), results[i]));
     exec_results.push_back(std::move(literal));
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 76d8b92..2e934bf 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -72,7 +72,7 @@
 
     // A pointer to a vector where the outfeed values will be stored. If
     // nullptr, the values will be read and discarded.
-    std::vector<std::unique_ptr<Literal>>* outfeed_values = nullptr;
+    std::vector<Literal>* outfeed_values = nullptr;
 
     // Whether the HLO passes should be run on the input module. Usually
     // saved modules are coming from after the HLO pass pipeline, so triggering
@@ -106,24 +106,23 @@
   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
       const absl::Span<const Literal* const> literals);
   StatusOr<std::vector<ScopedShapedBuffer>> TransferLiteralsToDevice(
-      const absl::Span<const std::unique_ptr<Literal>> literals);
-  StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
-      const ShapedBuffer& buffer);
+      const absl::Span<const Literal> literals);
+  StatusOr<Literal> TransferLiteralFromDevice(const ShapedBuffer& buffer);
 
   // Executes the given module with given literals as input and returns the
   // result as a Literal.
   //
   // If run_hlo_passes is false, the module will be executed without Hlo
   // optimization.
-  StatusOr<std::unique_ptr<Literal>> Execute(
-      std::unique_ptr<HloModule> module,
-      const absl::Span<const Literal* const> arguments,
-      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+                            const absl::Span<const Literal* const> arguments,
+                            bool run_hlo_passes = true,
+                            ExecutionProfile* profile = nullptr);
 
-  StatusOr<std::unique_ptr<Literal>> Execute(
-      std::unique_ptr<HloModule> module,
-      const absl::Span<const std::unique_ptr<Literal>> arguments,
-      bool run_hlo_passes = true, ExecutionProfile* profile = nullptr);
+  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+                            const absl::Span<const Literal> arguments,
+                            bool run_hlo_passes = true,
+                            ExecutionProfile* profile = nullptr);
 
   // As Execute(), but accepts and returns device buffers instead of host
   // buffers.
@@ -140,7 +139,7 @@
   // Executes a given HLO module into a set of replicas, and returns a map
   // with the replica number as key, and the corresponding returned literal as
   // value.
-  StatusOr<std::vector<std::unique_ptr<Literal>>> ExecuteReplicated(
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
       std::unique_ptr<HloModule> module,
       const ReplicatedExecuteOptions& options);
 
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
new file mode 100644
index 0000000..3fc5dbe
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -0,0 +1,343 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <queue>
+#include <vector>
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+    const HloModule* module, const HloScheduleProto& proto) {
+  tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+  for (const HloComputation* computation : module->computations()) {
+    id_to_computation[computation->unique_id()] = computation;
+  }
+
+  HloSchedule schedule(module);
+  for (const auto& id_sequence : proto.sequences()) {
+    int64 computation_id = id_sequence.first;
+
+    auto comp_it = id_to_computation.find(computation_id);
+    TF_RET_CHECK(comp_it != id_to_computation.end())
+        << "No computation exists in HLO module with id " << computation_id;
+    const HloComputation* computation = comp_it->second;
+
+    tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+    for (const HloInstruction* instruction : computation->instructions()) {
+      id_to_instruction[instruction->unique_id()] = instruction;
+    }
+
+    HloInstructionSequence& sequence =
+        schedule.GetOrCreateSequence(computation);
+    for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+      auto instr_it = id_to_instruction.find(instruction_id);
+      TF_RET_CHECK(instr_it != id_to_instruction.end())
+          << "No instruction exists in HLO computation " << computation->name()
+          << " with id " << instruction_id;
+      sequence.push_back(instr_it->second);
+    }
+  }
+  TF_RETURN_IF_ERROR(schedule.Verify());
+  return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+  TF_RETURN_IF_ERROR(Verify());
+  HloScheduleProto proto;
+  for (const auto& id_sequence : sequences_) {
+    int64 computation_id = id_sequence.first;
+    const HloInstructionSequence& sequence = id_sequence.second;
+    HloScheduleProto::InstructionSequence& proto_sequence =
+        (*proto.mutable_sequences())[computation_id];
+    proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+    for (const int64 id : sequence.ids()) {
+      proto_sequence.add_instruction_ids(id);
+    }
+  }
+  return std::move(proto);
+}
+
+void HloSchedule::set_sequence(
+    const HloComputation* computation,
+    absl::Span<const HloInstruction* const> sequence) {
+  set_sequence(computation, HloInstructionSequence(sequence));
+}
+
+void HloSchedule::set_sequence(const HloComputation* computation,
+                               HloInstructionSequence sequence) {
+  CHECK(computation->parent() == module_);
+  sequences_[computation->unique_id()] = std::move(sequence);
+}
+
+HloInstructionSequence& HloSchedule::GetOrCreateSequence(
+    const HloComputation* computation) {
+  auto it = sequences_.find(computation->unique_id());
+  if (it == sequences_.end()) {
+    // No sequence found for computation. Create and return an empty one.
+    CHECK(computation->parent() == module_);
+    return sequences_[computation->unique_id()];
+  } else {
+    return it->second;
+  }
+}
+
+const HloInstructionSequence& HloSchedule::sequence(
+    const HloComputation* computation) const {
+  return sequences_.at(computation->unique_id());
+}
+
+Status HloSchedule::UpdateComputationSchedule(
+    const HloComputation* computation) {
+  // Map from unique ID to HloInstruction pointer for instructions in the
+  // computation.
+  tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+  for (const HloInstruction* instruction : computation->instructions()) {
+    InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
+  }
+
+  // Set of all HloInstructions in the schedule.
+  tensorflow::gtl::FlatSet<int> ids_in_schedule;
+  for (int id : sequences_.at(computation->unique_id()).ids()) {
+    InsertOrDie(&ids_in_schedule, id);
+  }
+
+  // Map from HloInstruction X to newly added instructions (instruction is in
+  // computation, but not in schedule) which use X. If an instruction is not in
+  // the map, then it has no users which are newly added instructions.
+  tensorflow::gtl::FlatMap<const HloInstruction*,
+                           std::vector<const HloInstruction*>>
+      new_instruction_uses;
+
+  // For each newly added instruction, this is the count of the instruction's
+  // operands that have not yet been scheduled. When this value reaches zero,
+  // then the instruction may be placed in the schedule.
+  tensorflow::gtl::FlatMap<const HloInstruction*, int>
+      unscheduled_operand_count;
+
+  // Create a worklist of newly added instructions which are ready to be added
+  // to the schedule. Initialize worklist with those that have zero operands.
+  std::queue<const HloInstruction*> worklist;
+
+  for (const HloInstruction* instruction : computation->instructions()) {
+    if (ids_in_schedule.count(instruction->unique_id()) == 0) {
+      // This is a newly added instruction which is not in the schedule.
+      if (instruction->operands().empty()) {
+        worklist.push(instruction);
+      } else {
+        for (const HloInstruction* operand : instruction->operands()) {
+          new_instruction_uses[operand].push_back(instruction);
+        }
+        unscheduled_operand_count[instruction] = instruction->operand_count();
+      }
+    }
+  }
+
+  // Update the schedule with the newly added instructions, and remove any
+  // instructions no longer in the graph.
+  HloInstructionSequence new_sequence;
+
+  // Lambda which schedules all instructions on the worklist.
+  auto schedule_worklist = [&]() {
+    while (!worklist.empty()) {
+      const HloInstruction* instruction = worklist.front();
+      worklist.pop();
+      new_sequence.push_back(instruction);
+      std::vector<const HloInstruction*>* new_users =
+          tensorflow::gtl::FindOrNull(new_instruction_uses, instruction);
+      if (new_users != nullptr) {
+        // This just-scheduled instruction has users which are newly added to
+        // the module. Update the number of unscheduled operands and push the
+        // newly added instruction to the worklist if it is ready to
+        // schedule.
+        for (const HloInstruction* new_user : *new_users) {
+          unscheduled_operand_count.at(new_user)--;
+          CHECK_GE(unscheduled_operand_count.at(new_user), 0);
+          if (unscheduled_operand_count.at(new_user) == 0) {
+            worklist.push(new_user);
+          }
+        }
+      }
+    }
+  };
+
+  schedule_worklist();
+  for (int id : sequences_.at(computation->unique_id()).ids()) {
+    auto it = id_to_instruction.find(id);
+    if (it == id_to_instruction.end()) {
+      // This instruction in the schedule is no longer in the module. Do not add
+      // it to the new schedule.
+      continue;
+    }
+    worklist.push(it->second);
+    schedule_worklist();
+  }
+
+  set_sequence(computation, std::move(new_sequence));
+  return Status::OK();
+}
+
+Status HloSchedule::Update() {
+  // The schedule must contain a sequence for every non-fusion computation in
+  // the module, but can have sequences for computations which no longer exist
+  // (these are removed).
+  std::vector<HloComputation*> nonfusion_computations =
+      module_->MakeNonfusionComputations();
+  for (const HloComputation* computation : nonfusion_computations) {
+    TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+        << "Computation " << computation->name() << " not in HloSchedule.";
+  }
+  if (sequences_.size() > nonfusion_computations.size()) {
+    // Schedule contains some computations which have been removed from the
+    // HloModule. Remove them from the schedule as well.
+    tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+    for (const HloComputation* computation : nonfusion_computations) {
+      nonfusion_computations_ids.insert(computation->unique_id());
+    }
+    for (auto it = sequences_.begin(); it != sequences_.end();) {
+      if (nonfusion_computations_ids.count(it->first) == 0) {
+        it = sequences_.erase(it);
+      } else {
+        it++;
+      }
+    }
+  }
+  CHECK_EQ(sequences_.size(), nonfusion_computations.size());
+
+  for (const HloComputation* computation : nonfusion_computations) {
+    TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation));
+  }
+
+  TF_RETURN_IF_ERROR(Verify());
+  return Status::OK();
+}
+
+Status HloSchedule::Verify() const {
+  VLOG(2) << "VerifySchedule()";
+  XLA_VLOG_LINES(3, module_->ToString());
+  XLA_VLOG_LINES(2, ToString());
+
+  // Verify schedule contains exactly the same set of non-fusion computations as
+  // module currently does.
+  std::vector<HloComputation*> nonfusion_computations =
+      module_->MakeNonfusionComputations();
+  TF_RET_CHECK(nonfusion_computations.size() == sequences_.size())
+      << "Schedule has " << sequences_.size() << " sequences, but module has "
+      << nonfusion_computations.size() << " non-fusion computations";
+  for (const HloComputation* computation : nonfusion_computations) {
+    TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
+        << "Computation " << computation->name()
+        << " missing from HLO schedule.";
+  }
+
+  // For each computation verify the set of instructions is the same and that
+  // each dependency and control edge is honored.
+  for (const HloComputation* computation : nonfusion_computations) {
+    tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+    int pos = 0;
+    for (const HloInstruction* instruction :
+         sequence(computation).instructions()) {
+      TF_RET_CHECK(instruction_position.insert({instruction, pos}).second)
+          << "Instruction " << instruction->name()
+          << " appears more than once in the schedule";
+      pos++;
+    }
+
+    TF_RET_CHECK(instruction_position.size() ==
+                 computation->instruction_count());
+    for (const HloInstruction* instruction : computation->instructions()) {
+      TF_RET_CHECK(instruction_position.count(instruction) == 1)
+          << "Instruction " << instruction->name() << " is not in schedule";
+    }
+
+    for (const HloInstruction* instruction : computation->instructions()) {
+      for (const HloInstruction* operand : instruction->operands()) {
+        TF_RET_CHECK(instruction_position.at(operand) <
+                     instruction_position.at(instruction))
+            << "Instruction " << instruction->name()
+            << " is not scheduled after its operand " << operand->name();
+      }
+
+      for (const HloInstruction* pred : instruction->control_predecessors()) {
+        TF_RET_CHECK(instruction_position.at(pred) <
+                     instruction_position.at(instruction))
+            << "Instruction " << instruction->name()
+            << " is not scheduled after its control predecessor "
+            << pred->name();
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+namespace {
+
+// Returns the computation in the given module with the given unique ID. Returns
+// nullptr if no such computation exists.
+const HloComputation* IdToComputation(const HloModule* module, int64 id) {
+  for (const HloComputation* computation : module->computations()) {
+    if (computation->unique_id() == id) {
+      return computation;
+    }
+  }
+  return nullptr;
+}
+
+}  // namespace
+
+string HloSchedule::ToString() const {
+  std::vector<string> pieces;
+
+  pieces.push_back("HloSchedule");
+  for (const auto& id_sequence : sequences_) {
+    const HloComputation* computation =
+        IdToComputation(module_, id_sequence.first);
+    if (computation == nullptr) {
+      // The computation is not in the module and may have been deleted so it is
+      // not safe to dereference any HLO pointers. Just use the HLO unique ids
+      // stored in this object.
+      pieces.push_back(
+          absl::StrFormat("computation with id %d (no longer in HLO module):",
+                          id_sequence.first));
+      for (int id : id_sequence.second.ids()) {
+        pieces.push_back(absl::StrCat("  ", id));
+      }
+    } else {
+      pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
+      for (const HloInstruction* instruction :
+           id_sequence.second.instructions()) {
+        pieces.push_back(absl::StrCat("  ", instruction->name()));
+      }
+    }
+  }
+  return absl::StrJoin(pieces, "\n");
+}
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) {
+  out << schedule.ToString();
+  return out;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
new file mode 100644
index 0000000..270fe60
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
+
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+
+class HloModule;
+
+// Class representing a sequence of HLO instructions such as the sequential
+// execution order of an HLO computation.
+class HloInstructionSequence {
+ public:
+  HloInstructionSequence() = default;
+  explicit HloInstructionSequence(
+      absl::Span<const HloInstruction* const> instructions) {
+    for (const HloInstruction* instruction : instructions) {
+      push_back(instruction);
+    }
+  }
+
+  // Adds the instruction to the end of the sequence.
+  void push_back(const HloInstruction* instruction) {
+    instruction_sequence_.push_back(instruction);
+    id_sequence_.push_back(instruction->unique_id());
+  }
+
+  // Clears the sequence of all instructions.
+  void clear() {
+    instruction_sequence_.clear();
+    id_sequence_.clear();
+  }
+
+  int64 size() const { return instruction_sequence_.size(); }
+
+  // Returns the sequence of HLO instructions.
+  const std::vector<const HloInstruction*>& instructions() const {
+    return instruction_sequence_;
+  }
+
+  // Returns the unique IDs of the instructions in the sequence (in order).
+  const std::vector<int>& ids() const { return id_sequence_; }
+
+ private:
+  // The sequence as HloInstructions.
+  std::vector<const HloInstruction*> instruction_sequence_;
+
+  // The sequence of HLO instructions, represented by their unique IDs. The
+  // sequence is stored as both HloInstructions and unique IDs because the
+  // sequence may be referenced after transformations to the HLO graph and HLO
+  // pointers can be invalidated or recycled in this process (see
+  // HloSchedule::Update).
+  std::vector<int> id_sequence_;
+};
+
+// A class representing a sequential schedule of instructions for an HLO
+// module. A complete HLO schedule contains an instruction sequence for every
+// non-fusion computation in the HLO module.
+class HloSchedule {
+ public:
+  explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+  // (De)Serialize an HloSchedule to/from a HloScheduleProto.
+  static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+                                               const HloScheduleProto& proto);
+  StatusOr<HloScheduleProto> ToProto() const;
+
+  // Returns a reference to the sequence for the given computation.
+  const HloInstructionSequence& sequence(
+      const HloComputation* computation) const;
+
+  // Returns the sequence for the given computation. An empty sequence is
+  // created if none exists for the computation.
+  HloInstructionSequence& GetOrCreateSequence(
+      const HloComputation* computation);
+
+  // Sets the sequence for the given computation to the given sequence.
+  void set_sequence(const HloComputation* computation,
+                    absl::Span<const HloInstruction* const> sequence);
+  void set_sequence(const HloComputation* computation,
+                    HloInstructionSequence sequence);
+
+  // Returns a map from HloComputation unique ID to instruction sequence. The
+  // map contains all sequences in the schedule.
+  const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
+      const {
+    return sequences_;
+  }
+
+  // Returns true if the schedule has a sequence for the given computation.
+  bool is_computation_scheduled(const HloComputation* computation) const {
+    return sequences_.count(computation->unique_id()) == 1;
+  }
+
+  // Updates the schedule such that it is (again) a valid schedule for the
+  // module. This is used to update a schedule after the HLO module has been
+  // transformed in some way. In general, the only transformations to the module
+  // for which a schedule can be updated are the addition or removal of
+  // instructions and removal of computations. Updating the schedule after new
+  // dependencies between existing instructions in the module is not supported
+  // and may result in an error status returned.
+  //
+  // Instructions in the module which also exist in the given schedule will
+  // remain in the same order in the updated schedule. Instructions which exist
+  // in the module but not in the given schedule will be placed as early as
+  // possible in the updated schedule.
+  Status Update();
+
+  // Verifies that the given schedule is valid for the given module.
+  // Specifically, the schedule contains exactly the instructions in the
+  // non-fusion computations in the module and every dependency in the module is
+  // satisfied in the schedule.
+  Status Verify() const;
+
+  string ToString() const;
+
+  bool empty() const { return sequences_.empty(); }
+
+  const HloModule* module() const { return module_; }
+
+ private:
+  // Updates the instruction sequence for the given computation.
+  Status UpdateComputationSchedule(const HloComputation* computation);
+
+  const HloModule* module_;
+
+  // A map from computation unique ID to instruction sequence. Unique IDs are
+  // used rather than HloComputation pointers because HLO pointers are not
+  // unique across HLO transformations because pointers may be recycled.
+  tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
new file mode 100644
index 0000000..1424569
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloScheduleTest : public HloTestBase {};
+
+TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
+  // Updating the schedule of an unchanged HLO module should not affect the
+  // schedule at all.
+  const string module_str = R"(
+HloModule UpdateScheduleUnchanged
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  c = f32[] constant(42.0)
+  sum = f32[] add(a, b)
+  neg = f32[] negate(c)
+  ROOT root = f32[] multiply(sum, neg)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
+  const std::vector<const HloInstruction*>& entry_schedule =
+      schedule.sequence(module->entry_computation()).instructions();
+
+  EXPECT_EQ(entry_schedule.size(), 6);
+
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(entry_schedule,
+            schedule.sequence(module->entry_computation()).instructions());
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
+  // Add some additional instructions to a module and verify the schedule can be
+  // updated.
+  const string module_str = R"(
+HloModule UpdateScheduleWithNewInstructions
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  c = f32[] constant(42.0)
+  sum = f32[] add(a, b)
+  neg = f32[] negate(c)
+  ROOT root = f32[] multiply(sum, neg)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
+
+  HloComputation* entry = module->entry_computation();
+  const Shape shape = entry->root_instruction()->shape();
+  HloInstruction* constant = entry->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
+      shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
+  entry->set_root_instruction(sub);
+
+  auto in_schedule = [&](const HloInstruction* hlo) {
+    return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
+  };
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 6);
+  EXPECT_FALSE(in_schedule(constant));
+  EXPECT_FALSE(in_schedule(sub));
+
+  ASSERT_IS_NOT_OK(schedule.Verify());
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 8);
+  EXPECT_TRUE(in_schedule(constant));
+  EXPECT_TRUE(in_schedule(sub));
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
+  // Add and delete some instructions from a module and verify that the schedule
+  // can be updated successfully.
+  const string module_str = R"(
+HloModule UpdateScheduleWithAddedAndDeletedInstruction
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  c = f32[] constant(42.0)
+  sum = f32[] add(a, b)
+  neg = f32[] negate(c)
+  ROOT root = f32[] multiply(sum, neg)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
+
+  // Set the entry root to some expression containing just a parameter and a
+  // constant.
+  HloComputation* entry = module->entry_computation();
+  HloInstruction* constant = entry->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+  HloInstruction* new_root = entry->AddInstruction(
+      HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
+                                   constant, entry->parameter_instruction(0)));
+  entry->set_root_instruction(new_root);
+
+  // DCE should remove everything but the parameters and the newly added code.
+  HloDCE dce;
+  TF_ASSERT_OK(dce.Run(module.get()).status());
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 6);
+
+  ASSERT_IS_NOT_OK(schedule.Verify());
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 4);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
+  // Completely replace a module with an entirely new set of instructions and
+  // verify that the schedule can be updated successfully.
+  const string module_str = R"(
+HloModule UpdateScheduleWithCompletelyReplacedModule
+
+ENTRY main {
+  a = f32[] constant(42.0)
+  b = f32[] constant(123.0)
+  ROOT sum = f32[] add(a, b)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape());
+      }));
+
+  // Replace the entry computation with the negation of a constant.
+  HloComputation* entry = module->entry_computation();
+  HloInstruction* constant = entry->AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+  HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
+      constant->shape(), HloOpcode::kNegate, constant));
+  entry->set_root_instruction(new_root);
+
+  // DCE the old instructions.
+  HloDCE dce;
+  TF_ASSERT_OK(dce.Run(module.get()).status());
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 3);
+
+  ASSERT_IS_NOT_OK(schedule.Verify());
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(schedule.sequence(entry).size(), 2);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
+  // Create changes to more than one computation in an HLO module and verify
+  // that the schedule can be updated.
+  const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+  %param.1 = (s32[], token[]) parameter(0)
+  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+  %constant.1 = s32[] constant(1)
+  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+  %after-all = token[] after-all(token[] %get-tuple-element.2)
+  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+  %param = (s32[], token[]) parameter(0)
+  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+  %constant = s32[] constant(42)
+  ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+  %zero = s32[] constant(0)
+  %init_token = token[] after-all()
+  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape(),
+                                     /*pointer_size=*/sizeof(void*));
+      }));
+
+  const HloInstruction* xla_while =
+      module->entry_computation()->root_instruction()->operand(0);
+  HloComputation* body = xla_while->while_body();
+  HloComputation* cond = xla_while->while_condition();
+
+  // Negate the root of the cond.
+  cond->set_root_instruction(cond->AddInstruction(
+      HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
+                                  HloOpcode::kNot, cond->root_instruction())));
+
+  // Replace the body with a computation which just passes through its
+  // parameter.
+  body->set_root_instruction(body->parameter_instruction(0));
+
+  // DCE the dead code in the body.
+  HloDCE dce;
+  TF_ASSERT_OK(dce.Run(module.get()).status());
+
+  EXPECT_EQ(schedule.sequence(body).size(), 7);
+  EXPECT_EQ(schedule.sequence(cond).size(), 4);
+
+  ASSERT_IS_NOT_OK(schedule.Verify());
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+
+  EXPECT_EQ(schedule.sequence(body).size(), 1);
+  EXPECT_EQ(schedule.sequence(cond).size(), 5);
+}
+
+TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
+  // Remove computations from a module and verify the schedule can be updated.
+  const string module_str = R"(
+HloModule UpdateScheduleWithMultipleComputations
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+  %param.1 = (s32[], token[]) parameter(0)
+  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+  %constant.1 = s32[] constant(1)
+  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+  %after-all = token[] after-all(token[] %get-tuple-element.2)
+  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+  %param = (s32[], token[]) parameter(0)
+  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+  %constant = s32[] constant(42)
+  ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %WhileLoop () -> s32[] {
+  %zero = s32[] constant(0)
+  %init_token = token[] after-all()
+  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+  TF_ASSERT_OK_AND_ASSIGN(
+      HloSchedule schedule,
+      ScheduleModule(*module, [](const BufferValue& buffer) {
+        return ShapeUtil::ByteSizeOf(buffer.shape(),
+                                     /*pointer_size=*/sizeof(void*));
+      }));
+
+  HloInstruction* xla_while =
+      module->entry_computation()->root_instruction()->mutable_operand(0);
+  HloInstruction* init = xla_while->mutable_operand(0);
+
+  // Replace the while with its init value. The conditional and body
+  // computations should then be dead.
+  TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));
+
+  // DCE the now-dead while body and condition computations.
+  HloDCE dce;
+  ASSERT_EQ(module->computation_count(), 3);
+  TF_ASSERT_OK(dce.Run(module.get()).status());
+  ASSERT_EQ(module->computation_count(), 1);
+
+  ASSERT_IS_NOT_OK(schedule.Verify());
+  TF_ASSERT_OK(schedule.Update());
+  TF_ASSERT_OK(schedule.Verify());
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
deleted file mode 100644
index d06b8d9..0000000
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ /dev/null
@@ -1,127 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
-
-#include <vector>
-
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-
-namespace xla {
-
-// A memory scheduler computes an execution sequence for the HLO instructions in
-// 'computation' that minimizes peak memory, given a points-to analysis result
-// that describes buffer aliasing, together with a target-specific size function
-// that maps a tensor's logical size to its padded size.
-typedef std::function<StatusOr<std::vector<const HloInstruction*>>(
-    const HloComputation&, const TuplePointsToAnalysis&,
-    const LogicalBuffer::SizeFunction&,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
-    MemorySchedulerAlgorithm;
-
-// List scheduler
-StatusOr<std::vector<const HloInstruction*>> ListMemoryScheduler(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation);
-
-// DFS-order scheduler
-StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation);
-
-// Naive Post Order scheduler
-StatusOr<std::vector<const HloInstruction*>> PostOrderMemoryScheduler(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation);
-
-// The default scheduling algorithm. Runs both the list scheduler
-// and the DFS scheduler, and chooses whichever returns a lower min-memory,
-// not accounting for fragmentation.
-StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
-    const HloComputation& computation,
-    const TuplePointsToAnalysis& points_to_analysis,
-    const LogicalBuffer::SizeFunction& size_function,
-    const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
-        memory_by_computation);
-
-// Returns an HloModuleSequence which seeks to minimize the memory required for
-// the computation. size_function is the function returning the number of bytes
-// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
-    const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
-    const MemorySchedulerAlgorithm& algorithm = {});
-
-// Computes the schedule for a single computation.
-// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
-    const HloComputation& computation,
-    const LogicalBuffer::SizeFunction& size_function);
-
-// Transforms the given schedule such that it is (again) a valid schedule for
-// the module. This is used to update a schedule after the HLO module has been
-// transformed in some way. In general, the only transformations to the module
-// for which a schedule can be updated is the addition or removal of
-// instructions to/from the module. Updating the schedule after new dependencies
-// between existing instructions in the module is not supported and may result
-// in an error status returned.
-//
-// Instructions in the module which also exist in the given schedule will remain
-// in the same order in the updated schedule. Instructions which exist in the
-// module but not in the given schedule will be placed as early as possible in
-// the updated schedule.
-//
-// 'id_sequence' is a mirror of the given schedule 'sequence' but with
-// HloInstruction ids rather than HloInstruction pointers. This should be
-// constructed using ComputeIdSchedule below after the schedule is constructed
-// but before the HLO module is transformed.
-Status UpdateSchedule(
-    const HloModule& module,
-    const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
-        id_sequence,
-    SequentialHloOrdering::HloModuleSequence* sequence);
-
-// Constructs a copy of the given schedule but with HloInstruction unique ids
-// rather than HloInstruction pointers. This is necessary for updating a
-// schedule as HloInstruction points in the schedule may become invalid if
-// instructions are removed from the module. Used by UpdateSchedule above..
-// TODO(b/113175018): Remove this function when HLO schedule is its own class.
-tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence);
-
-// Verifies that the given schedule is valid for the given module. Specifically,
-// the schedule contains exactly the instructions in the module and every
-// dependency in the module is satisfied in the schedule.
-Status VerifySchedule(const HloModule& module,
-                      const SequentialHloOrdering::HloModuleSequence& sequence);
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
deleted file mode 100644
index d49d09d..0000000
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ /dev/null
@@ -1,667 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
-
-#include <memory>
-#include <string>
-
-#include "tensorflow/compiler/xla/service/heap_simulator.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_dce.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include "tensorflow/compiler/xla/service/hlo_parser.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-
-namespace xla {
-namespace {
-
-class HloSchedulingTest : public HloTestBase {};
-
-TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
-  // Tests scheduling of the following HLO code:
-  //
-  //   %ab = abs(%param)
-  //   %exp = exp(%param)
-  //   %add = add(%ab, %exp)
-  //   %negate = negate(%exp)
-  //   %sub = subtract(%add, %negate)
-  //
-  // %add should be scheduled before %negate because %add is the last (and only)
-  // use of %ab. Scheduling %add first then frees up %ab's buffer.
-  const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
-  auto builder = HloComputation::Builder(TestName());
-  auto param =
-      builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
-  auto ab = builder.AddInstruction(
-      HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
-  auto exp = builder.AddInstruction(
-      HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));
-
-  auto add = builder.AddInstruction(
-      HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
-  auto negate = builder.AddInstruction(
-      HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
-  auto sub = builder.AddInstruction(
-      HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));
-
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape());
-      }));
-  // Verify that all instructions are in the sequence.
-  EXPECT_EQ(module->entry_computation()->instruction_count(),
-            sequence.at(module->entry_computation()).size());
-
-  // The first instruction should be the parameter and the last the root "sub".
-  EXPECT_EQ(param, sequence.at(module->entry_computation()).front());
-  EXPECT_EQ(sub, sequence.at(module->entry_computation()).back());
-
-  SequentialHloOrdering ordering(module.get(), sequence);
-  EXPECT_TRUE(ordering.ExecutesBefore(add, negate));
-}
-
-TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
-  const char* module_str = R"(
-HloModule test_aliasing_module
-
-ENTRY root {
-  param = s32[1000] parameter(0)
-  p0 = s32[1000] copy(param)
-  p1 = s32[1000] copy(param)
-  t = (s32[1000], s32[1000]) tuple(p0, p1)
-  a = s32[1000] get-tuple-element(t), index=0
-  b = s32[1000] get-tuple-element(t), index=1
-  c = s32[1000] add(a, b)
-  d = s32[1000] add(c, b)
-  e = s32[1000] add(c, c)
-  f = s32[1000] add(e, e)
-  ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
-})";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-
-  auto size_fn = [](const BufferValue& buffer) {
-    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
-  };
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
-  // Verify that all instructions are in the sequence.
-  EXPECT_EQ(module->entry_computation()->instruction_count(),
-            sequence.at(module->entry_computation()).size());
-
-  std::unordered_map<string, const HloInstruction*> instructions_by_name;
-  for (const HloInstruction* instruction :
-       sequence.at(module->entry_computation())) {
-    instructions_by_name[instruction->name()] = instruction;
-  }
-
-  // The first instruction should be the parameter and the last the root.
-  EXPECT_EQ(instructions_by_name.at("param"),
-            sequence.at(module->entry_computation()).front());
-  EXPECT_EQ(instructions_by_name.at("result"),
-            sequence.at(module->entry_computation()).back());
-
-  // Instructions "d" and "e" will both be schedulable at the same time, but
-  // instruction "d" allows us to free the buffer of "p1", so the list scheduler
-  // should prefer it.
-  SequentialHloOrdering ordering(module.get(), sequence);
-  EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
-                                      instructions_by_name.at("e")));
-}
-
-TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
-  // %WhileCond (cond_param: f32[4]) -> pred[] {
-  //   %cond_param = f32[4]{0} parameter(0)
-  //   %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } })
-  //   ROOT %not-equal-to = pred[] not-equal-to(
-  //     f32[4]{0} %cond_param, f32[1,4]{1,0} %constant)
-  // }
-  // %WhileBody (body_param: f32[4]) -> f32[4] {
-  //   %body_param = f32[4]{0} parameter(0)
-  //   %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
-  //   ROOT %subtract = f32[4]{0} subtract(
-  //     f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
-  // }
-  // %ListAccountsForSubcomputations () -> f32[2,4] {
-  //   %constant.3 = f32[2,4]{1,0} constant(
-  //     f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
-  //   %transpose = f32[2,4]{1,0} transpose(
-  //     f32[2,4]{1,0} %constant.3), dimensions={0,1}
-  //   %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } })
-  //   %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2),
-  //      condition=%WhileCond,
-  //      body=%WhileBody
-  //   %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0}
-  //   ROOT %add = f32[2,4]{1,0} add(
-  //     f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast)
-  // }
-
-  auto module = CreateNewModule();
-  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
-  const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
-
-  // param != 0
-  // Needs 17 bytes
-  auto cond_builder = HloComputation::Builder("WhileCond");
-  HloInstruction* cond_param = cond_builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
-  HloInstruction* zero_vector =
-      cond_builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
-  cond_builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
-  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
-
-  // param - 1
-  // Needs 16 bytes
-  auto body_builder = HloComputation::Builder("WhileBody");
-  HloInstruction* body_param = body_builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1f32, "body_param"));
-  HloInstruction* one_vector =
-      body_builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
-  body_builder.AddInstruction(HloInstruction::CreateBinary(
-      r1f32, HloOpcode::kSubtract, body_param, one_vector));
-  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
-
-  // transpose(matrix) + bcast(while)
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* while_init =
-      builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
-  // Creates 16 bytes, ignoring subcomputations
-  HloInstruction* while_loop =
-      builder.AddInstruction(HloInstruction::CreateWhile(
-          r1f32, cond_computation, body_computation, while_init));
-
-  // Creates 32 bytes and frees 16
-  HloInstruction* bcast = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
-
-  HloInstruction* matrix = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
-          {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
-  // Creates 32 bytes
-  HloInstruction* transpose = builder.AddInstruction(
-      HloInstruction::CreateTranspose(r2f32, matrix, {0, 1}));
-
-  // Creates 32 bytes and frees 64
-  HloInstruction* add = builder.AddInstruction(
-      HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast));
-
-  module->AddEntryComputation(builder.Build());
-
-  auto size_fn = [](const BufferValue& buffer) {
-    return ShapeUtil::ByteSizeOf(buffer.shape());
-  };
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
-  // Verify that all instructions are in the sequence.
-  auto entry_computation = module->entry_computation();
-  EXPECT_EQ(entry_computation->instruction_count(),
-            sequence.at(entry_computation).size());
-  SequentialHloOrdering ordering(module.get(), sequence);
-  // This schedule is an example of List's greedy heuristics being suboptimal.
-  // The while_loop is more expensive than transpose, so it would have been
-  // better to schedule it first, instead of during the busy time.
-  EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop));
-  EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
-  EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
-  EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
-
-  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
-  memory_by_computation[cond_computation] = 17;
-  memory_by_computation[body_computation] = 16;
-  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
-      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
-
-  // HeapSimulator doesn't account for subcomputations
-  EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
-                    *entry_computation, sequence.at(entry_computation),
-                    *points_to_analysis, size_fn)
-                    .ValueOrDie());
-  // HeapSimulator accounts for subcomputations. The output buffer is aliased,
-  // so we don't double count.
-  EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation(
-                    *entry_computation, sequence.at(entry_computation),
-                    *points_to_analysis, size_fn, &memory_by_computation)
-                    .ValueOrDie());
-}
-
-TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
-  auto builder = HloComputation::Builder(TestName());
-  const auto TUPLE_SIZE = 1;
-  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});
-
-  // Wrap lit in abs because constants are considered free by
-  // IgnoreInstruction, and it skews the accounting.
-  auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
-  auto abs_const = builder.AddInstruction(
-      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
-
-  auto abs_abs1 = builder.AddInstruction(
-      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
-  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
-      absl::Span<HloInstruction* const>({abs_abs1})));
-  auto tuple_elm = builder.AddInstruction(
-      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
-
-  auto abs_abs2 = builder.AddInstruction(
-      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
-
-  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
-                                                      tuple_elm, abs_abs2));
-
-  auto module = CreateNewModule();
-  module->AddEntryComputation(builder.Build());
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module,
-                                   [](const BufferValue& buffer) {
-                                     return ShapeUtil::ByteSizeOf(
-                                         buffer.shape(), TUPLE_SIZE);
-                                   },
-                                   ListMemoryScheduler));
-
-  // Verify that all instructions are in the sequence.
-  EXPECT_EQ(module->entry_computation()->instruction_count(),
-            sequence.at(module->entry_computation()).size());
-  SequentialHloOrdering ordering(module.get(), sequence);
-  // tuple allocates the tuple buffer and doesn't free anything.
-  // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
-  // abs_abs2 should be scheduled before tuple by List.
-  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
-}
-
-TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
-  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
-  HloComputation::Builder builder(TestName());
-
-  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
-  auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
-  auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
-      LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));
-
-  auto add = builder.AddInstruction(
-      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
-  auto mul = builder.AddInstruction(
-      HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
-  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));
-
-  auto tuple_elm = builder.AddInstruction(
-      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));
-
-  auto exp = builder.AddInstruction(
-      HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));
-
-  builder.AddInstruction(
-      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));
-
-  auto module = CreateNewModule();
-  auto* computation = module->AddEntryComputation(builder.Build());
-
-  auto fusion = computation->CreateFusionInstruction(
-      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);
-
-  TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
-                          ScheduleComputationsInModule(
-                              *module,
-                              [](const BufferValue& buffer) {
-                                return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
-                              },
-                              ListMemoryScheduler));
-
-  // Verify that all instructions are in the sequence.
-  EXPECT_EQ(module->entry_computation()->instruction_count(),
-            sequence.at(module->entry_computation()).size());
-  SequentialHloOrdering ordering(module.get(), sequence);
-  // fusion allocates memory for the tuple elements and doesn't free anything,
-  // so it's more expensive than exp.
-  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
-}
-
-TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
-  auto module = CreateNewModule();
-  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
-
-  // param != 0
-  // Needs 17 bytes
-  auto cond_builder = HloComputation::Builder("WhileCond");
-  HloInstruction* cond_param = cond_builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
-  HloInstruction* zero_vector =
-      cond_builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
-  cond_builder.AddInstruction(HloInstruction::CreateBinary(
-      ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
-  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
-
-  // param - 1
-  // Needs 16 bytes
-  auto body_builder = HloComputation::Builder("WhileBody");
-  HloInstruction* body_param = body_builder.AddInstruction(
-      HloInstruction::CreateParameter(0, r1f32, "body_param"));
-  HloInstruction* one_vector =
-      body_builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
-  body_builder.AddInstruction(HloInstruction::CreateBinary(
-      r1f32, HloOpcode::kSubtract, body_param, one_vector));
-  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
-
-  auto builder = HloComputation::Builder(TestName());
-  HloInstruction* while_init =
-      builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
-  // Creates 16 bytes, ignoring subcomputations
-  builder.AddInstruction(HloInstruction::CreateWhile(
-      r1f32, cond_computation, body_computation, while_init));
-
-  module->AddEntryComputation(builder.Build());
-
-  auto size_fn = [](const BufferValue& buffer) {
-    return ShapeUtil::ByteSizeOf(buffer.shape());
-  };
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
-  // Verify that all instructions are in the sequence.
-  auto entry_computation = module->entry_computation();
-  EXPECT_EQ(entry_computation->instruction_count(),
-            sequence.at(entry_computation).size());
-
-  tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
-  memory_by_computation[cond_computation] = 17;
-  memory_by_computation[body_computation] = 16;
-  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
-      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
-
-  // HeapSimulator doesn't account for subcomputations
-  EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
-                    *entry_computation, sequence.at(entry_computation),
-                    *points_to_analysis, size_fn)
-                    .ValueOrDie());
-  // HeapSimulator accounts for subcomputations. Cond is the largest one.
-  // The output buffer of the while is aliased.
-  EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
-                    *entry_computation, sequence.at(entry_computation),
-                    *points_to_analysis, size_fn, &memory_by_computation)
-                    .ValueOrDie());
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) {
-  // Updating the schedule of an unchanged HLO module should not affect the
-  // schedule at all.
-  const string module_str = R"(
-HloModule UpdateScheduleUnchanged
-
-ENTRY main {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  c = f32[] constant(42.0)
-  sum = f32[] add(a, b)
-  neg = f32[] negate(c)
-  ROOT root = f32[] multiply(sum, neg)
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape());
-      }));
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-      id_sequence = ComputeIdSchedule(sequence);
-  std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second;
-
-  EXPECT_EQ(entry_schedule.size(), 6);
-
-  TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
-  TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
-  EXPECT_EQ(entry_schedule, sequence.begin()->second);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) {
-  // Add some additional instructions to a module and verify the schedule can be
-  // updated.
-  const string module_str = R"(
-HloModule UpdateScheduleWithNewInstructions
-
-ENTRY main {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  c = f32[] constant(42.0)
-  sum = f32[] add(a, b)
-  neg = f32[] negate(c)
-  ROOT root = f32[] multiply(sum, neg)
-}
-)";
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape());
-      }));
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-      id_sequence = ComputeIdSchedule(sequence);
-
-  HloComputation* entry = module->entry_computation();
-  const Shape shape = entry->root_instruction()->shape();
-  HloInstruction* constant = entry->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
-  HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
-      shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
-  entry->set_root_instruction(sub);
-
-  auto in_schedule = [&](const HloInstruction* hlo) {
-    return std::find(sequence.at(entry).begin(), sequence.at(entry).end(),
-                     hlo) != sequence.at(entry).end();
-  };
-
-  EXPECT_EQ(sequence.at(entry).size(), 6);
-  EXPECT_FALSE(in_schedule(constant));
-  EXPECT_FALSE(in_schedule(sub));
-
-  TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
-  TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
-  EXPECT_EQ(sequence.at(entry).size(), 8);
-  EXPECT_TRUE(in_schedule(constant));
-  EXPECT_TRUE(in_schedule(sub));
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) {
-  // Add and delete some instructions from a module and verify that the schedule
-  // can be updated successfully.
-  const string module_str = R"(
-HloModule UpdateScheduleWithAddedAndDeletedInstruction
-
-ENTRY main {
-  a = f32[] parameter(0)
-  b = f32[] parameter(1)
-  c = f32[] constant(42.0)
-  sum = f32[] add(a, b)
-  neg = f32[] negate(c)
-  ROOT root = f32[] multiply(sum, neg)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape());
-      }));
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-      id_sequence = ComputeIdSchedule(sequence);
-
-  // Set the entry root to some expression containing just a parameter and a
-  // constant.
-  HloComputation* entry = module->entry_computation();
-  HloInstruction* constant = entry->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
-  HloInstruction* new_root = entry->AddInstruction(
-      HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
-                                   constant, entry->parameter_instruction(0)));
-  entry->set_root_instruction(new_root);
-
-  // DCE should remove everything but the parameters and the newly added code.
-  HloDCE dce;
-  TF_ASSERT_OK(dce.Run(module.get()).status());
-
-  EXPECT_EQ(sequence.at(entry).size(), 6);
-
-  TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
-  TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
-  EXPECT_EQ(sequence.at(entry).size(), 4);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) {
-  // Completely replace a module with an entirely new set of instructions and
-  // verify that the schedule can be updated successfully.
-  const string module_str = R"(
-HloModule UpdateScheduleWithCompletelyReplacedModule
-
-ENTRY main {
-  a = f32[] constant(42.0)
-  b = f32[] constant(123.0)
-  ROOT sum = f32[] add(a, b)
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape());
-      }));
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-      id_sequence = ComputeIdSchedule(sequence);
-
-  // Replace the entry computation with the negation of a constant.
-  HloComputation* entry = module->entry_computation();
-  HloInstruction* constant = entry->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
-  HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
-      constant->shape(), HloOpcode::kNegate, constant));
-  entry->set_root_instruction(new_root);
-
-  // DCE the old instructions.
-  HloDCE dce;
-  TF_ASSERT_OK(dce.Run(module.get()).status());
-
-  EXPECT_EQ(sequence.at(entry).size(), 3);
-
-  TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
-  TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
-  EXPECT_EQ(sequence.at(entry).size(), 2);
-}
-
-TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) {
-  // Create changes to more than one computation in an HLO module and verify
-  // that the schedule can be updated.
-  const string module_str = R"(
-HloModule UpdateScheduleWithMultipleComputations
-
-%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
-  %param.1 = (s32[], token[]) parameter(0)
-  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
-  %constant.1 = s32[] constant(1)
-  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
-  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
-  %after-all = token[] after-all(token[] %get-tuple-element.2)
-  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
-}
-
-%Cond (param: (s32[], token[])) -> pred[] {
-  %param = (s32[], token[]) parameter(0)
-  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
-  %constant = s32[] constant(42)
-  ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
-}
-
-ENTRY %WhileLoop () -> s32[] {
-  %zero = s32[] constant(0)
-  %init_token = token[] after-all()
-  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
-  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
-  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
-}
-)";
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
-                          ParseHloString(module_str));
-  TF_ASSERT_OK_AND_ASSIGN(
-      SequentialHloOrdering::HloModuleSequence sequence,
-      ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
-        return ShapeUtil::ByteSizeOf(buffer.shape(),
-                                     /*pointer_size=*/sizeof(void*));
-      }));
-  tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>
-      id_sequence = ComputeIdSchedule(sequence);
-
-  const HloInstruction* xla_while =
-      module->entry_computation()->root_instruction()->operand(0);
-  HloComputation* body = xla_while->while_body();
-  HloComputation* cond = xla_while->while_condition();
-
-  // Negate the root of the cond.
-  cond->set_root_instruction(cond->AddInstruction(
-      HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
-                                  HloOpcode::kNot, cond->root_instruction())));
-
-  // Replace the body with a computation which just passes through its
-  // parameter.
-  body->set_root_instruction(body->parameter_instruction(0));
-
-  // DCE the dead code in the body.
-  HloDCE dce;
-  TF_ASSERT_OK(dce.Run(module.get()).status());
-
-  EXPECT_EQ(sequence.at(body).size(), 7);
-  EXPECT_EQ(sequence.at(cond).size(), 4);
-
-  TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence));
-  TF_ASSERT_OK(VerifySchedule(*module, sequence));
-
-  EXPECT_EQ(sequence.at(body).size(), 1);
-  EXPECT_EQ(sequence.at(cond).size(), 5);
-}
-
-}  // namespace
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 34cba61..e3f4a98 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -422,6 +422,13 @@
              : false;
 }
 
+size_t ShardingMetadata::Hash() const {
+  if (sharding_ != nullptr) {
+    return sharding_->Hash();
+  }
+  return static_cast<size_t>(0x297814aaad196e6dULL);
+}
+
 string ShardingMetadata::ToString() const {
   return sharding_ != nullptr ? sharding_->ToString() : "{}";
 }
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
index cba5db9..e3ae82a 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h
@@ -36,6 +36,8 @@
 
   bool Matches(const DomainMetadata& other) const override;
 
+  size_t Hash() const override;
+
   string ToString() const override;
 
   const HloSharding* sharding() const { return sharding_.get(); }
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a..6fd734a 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 
@@ -24,7 +24,7 @@
 
 using ::tensorflow::GraphDef;
 
-class HloTfGraphBuilderTest : public HloTestBase {
+class HloTfGraphBuilderTest : public HloVerifiedTestBase {
  protected:
   HloTfGraphBuilderTest() {}
   HloTfGraphBuilder generator_;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 95516de..50f39cb 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -86,8 +86,8 @@
       const Shape expected,
       ShapeInference::InferConvolveShape(
           convolution->operand(0)->shape(), convolution->operand(1)->shape(),
-          convolution->window(), convolution->convolution_dimension_numbers(),
-          convolution->feature_group_count()));
+          convolution->feature_group_count(), convolution->window(),
+          convolution->convolution_dimension_numbers()));
   return CheckShape(convolution, expected);
 }
 
@@ -1123,6 +1123,11 @@
 
   TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
 
+  // If the module has a schedule, it must be valid.
+  if (module->has_schedule()) {
+    TF_RETURN_IF_ERROR(module->schedule().Verify());
+  }
+
   return false;
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 0cac210..8f0423b 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -290,8 +290,8 @@
   padding_config.add_dimensions()->set_interior_padding(-1);
   builder.AddInstruction(HloInstruction::CreatePad(
       ShapeUtil::MakeShape(F32, {100}), param,
-      builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::Zero(F32).CloneToUnique())),
+      builder.AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
       padding_config));
 
   auto module = CreateNewModule();
@@ -314,8 +314,8 @@
   padding_config.add_dimensions()->set_interior_padding(-1);
   builder.AddInstruction(HloInstruction::CreatePad(
       ShapeUtil::MakeShape(F32, {100}), param,
-      builder.AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::Zero(F32).CloneToUnique())),
+      builder.AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
       padding_config));
 
   auto module = CreateNewModule();
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index a4de02a..06f0e1e 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -165,6 +165,7 @@
     TF_ASSIGN_OR_RETURN(
         computed_array,
         ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+                           instr->precision_config(),
                            FindOrDie(cache_, instr->operand(0)),
                            FindOrDie(cache_, instr->operand(1))));
   } else {
@@ -917,7 +918,7 @@
   // inner_broadcast_result is the Broadcast'(Const0) bit in
   // BinaryOp(Broadcast'(Const0), Const1)
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Literal> inner_broadcast_result,
+      Literal inner_broadcast_result,
       broadcast_const_operand->literal().Broadcast(
           scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
 
@@ -927,12 +928,12 @@
     TF_ASSIGN_OR_RETURN(
         literal_for_new_source,
         TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
-            opcode, scalar_indexed_const->literal(), *inner_broadcast_result)));
+            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
   } else {
     TF_ASSIGN_OR_RETURN(
         literal_for_new_source,
         TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
-            opcode, *inner_broadcast_result, scalar_indexed_const->literal())));
+            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
   }
 
   ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@@ -1030,7 +1031,8 @@
 StatusOr<Analysis::Array*>
 IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
     const Shape& shape, const DotDimensionNumbers& dim_numbers,
-    ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
+    const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+    ConstantArray* rhs) {
   VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
           << ToString(rhs);
   if (!CanFoldDotIntoIndexedArray(
@@ -1045,9 +1047,10 @@
   new_dim_numbers.set_lhs_contracting_dimensions(
       0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
 
-  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
-                      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
-                          new_dim_numbers, lhs->literal(), *rhs->literal())));
+  TF_ASSIGN_OR_RETURN(
+      Literal * literal_for_new_source,
+      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+          new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
 
   // The new source dimension is wherever the non-batch non-contracting LHS
   // dimension "went".
@@ -1063,7 +1066,8 @@
 StatusOr<Analysis::Array*>
 IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
     const Shape& shape, const DotDimensionNumbers& dim_numbers,
-    ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+    const PrecisionConfig& precision_config, ConstantArray* lhs,
+    ScalarIndexedConstantArray* rhs) {
   VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
           << ToString(rhs);
   if (!CanFoldDotIntoIndexedArray(
@@ -1079,9 +1083,10 @@
   new_dim_numbers.set_rhs_contracting_dimensions(
       0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
 
-  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
-                      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
-                          new_dim_numbers, *lhs->literal(), rhs->literal())));
+  TF_ASSIGN_OR_RETURN(
+      Literal * literal_for_new_source,
+      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+          new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
 
   // The new source dimension is wherever the non-batch non-contracting RHS
   // dimension "went".
@@ -1095,8 +1100,8 @@
 }
 
 StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
-    const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
-    Array* rhs) {
+    const Shape& shape, const DotDimensionNumbers& dim_numbers,
+    const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
   // Intuitively, if
   //
   //  - The LHS of a dot product is a gathered sequence of rows from a constant
@@ -1119,6 +1124,7 @@
           dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
     if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
       return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+                                              precision_config,
                                               lhs_indexed_array, rhs_constant);
     }
   }
@@ -1126,7 +1132,8 @@
   if (auto* rhs_indexed_array =
           dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
     if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
-      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
+                                              precision_config, lhs_constant,
                                               rhs_indexed_array);
     }
   }
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index dcfb725..df9cbab 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -267,14 +267,17 @@
 
   StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
       const Shape& shape, const DotDimensionNumbers& dim_numbers,
-      ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
+      const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
+      ConstantArray* rhs);
 
   StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
       const Shape& shape, const DotDimensionNumbers& dim_numbers,
-      ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+      const PrecisionConfig& precision_config, ConstantArray* lhs,
+      ScalarIndexedConstantArray* rhs);
 
   StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
                                       const DotDimensionNumbers& dim_numbers,
+                                      const PrecisionConfig& precision_config,
                                       Array* lhs, Array* rhs);
 
   // This tries to fold a ScalarIndexedArray which has another
@@ -344,21 +347,19 @@
     }
   }
 
-  Literal* TakeOwnership(std::unique_ptr<Literal> literal) {
+  Literal* TakeOwnership(Literal literal) {
     owned_literals_.push_back(std::move(literal));
-    return owned_literals_.back().get();
+    return &owned_literals_.back();
   }
 
-  StatusOr<Literal*> TakeOwnership(
-      StatusOr<std::unique_ptr<Literal>> literal_or_error) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
-                        std::move(literal_or_error));
+  StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
+    TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
     owned_literals_.push_back(std::move(literal));
-    return owned_literals_.back().get();
+    return &owned_literals_.back();
   }
 
   std::vector<std::unique_ptr<Array>> owned_tensors_;
-  std::vector<std::unique_ptr<Literal>> owned_literals_;
+  std::vector<Literal> owned_literals_;
   tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
 };
 
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 5695bc2..7e967f0 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -26,7 +26,7 @@
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
@@ -35,7 +35,7 @@
 namespace xla {
 namespace {
 
-using InlinerTest = HloTestBase;
+using InlinerTest = HloVerifiedTestBase;
 
 // Test that `map` with `max` is transformed to `max`
 TEST_F(InlinerTest, MapMax) {
@@ -64,14 +64,14 @@
   hlo_module->AddEntryComputation(std::move(computation));
 
   Inliner inliner;
-  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+  EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
   EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
               op::Maximum(lhs, rhs));
 
   // Verify execution on CPU.
-  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
   auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
 // Test that `constant` function is changed to `broadcast`.
@@ -98,14 +98,14 @@
   hlo_module->AddEntryComputation(std::move(computation));
   HloInstruction* root = hlo_module->entry_computation()->root_instruction();
   Inliner inliner;
-  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+  EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
   root = hlo_module->entry_computation()->root_instruction();
   EXPECT_THAT(root, op::Broadcast(op::Constant()));
 
   // Verify execution on CPU.
-  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
   auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
 TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -136,14 +136,14 @@
   hlo_module->AddEntryComputation(std::move(computation));
 
   Inliner inliner;
-  EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+  EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
   EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
           op::Subtract(rhs, lhs));
 
   // Verify execution on CPU.
-  auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+  auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
   auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
 }
 
 
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907ea..3fdc2ce 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@
   return do_not_duplicate;
 }
 
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create vector of instructions in post order and create a map from
+// HloInstruction* to the instruction's index in the vector. An instruction is
+// "removed" from the vector by setting its element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+  explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+    post_order_ = computation->MakeInstructionPostOrder();
+
+    for (size_t i = 0; i < post_order_.size(); ++i) {
+      InsertOrDie(&post_order_index_, post_order_[i], i);
+    }
+  }
+
+  std::pair<HloInstruction*, std::vector<int64>>
+  DequeueNextInstructionAndOperandsToFuseInOrder() override {
+    // Instructions are "removed" from the post order by nulling out the element
+    // in the vector, so if the pointer is null, continue to the next
+    // instruction in the sort.
+    while (!post_order_.empty() && post_order_.back() == nullptr) {
+      post_order_.pop_back();
+    }
+    if (post_order_.empty()) {
+      return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+    }
+    // We want to iterate in reverse post order, so remove from the back of the
+    // vector.
+    HloInstruction* instruction = post_order_.back();
+    post_order_.pop_back();
+
+    CHECK(instruction != nullptr);
+    // Remove instruction from the index map to ensure the vector and map stay
+    // consistent.
+    post_order_index_.erase(instruction);
+
+    // Consider each operand of this instruction for fusion into this
+    // instruction. We want to consider the operands in a particular order to
+    // avoid creating duplicate instruction clones in the fusion instruction.
+    // For example, consider the following expression:
+    //
+    //   A = ...
+    //   B = op(A)
+    //   C = op(A, B)
+    //
+    // If we are considering the operands of C for fusion into C. We might
+    // fuse A or B first. If we fuse A first, we get:
+    //
+    //   A = ...
+    //   B = op(A)
+    //   C_fusion = { A' = ...
+    //                C' = op(A', B) }
+    //
+    // Where A' and C' are clones of A and C, respectively. Now only B is an
+    // operand of the fusion instruction C_fusion, so then we fuse B:
+    //
+    //   A = ...
+    //   B = op(A)
+    //   C_fusion = { A' = ...
+    //                B' = op(A)
+    //                C' = op(A', B') }
+    //
+    // Now A is an operand of C_fusion again, so we then fuse A (again!):
+    //
+    //   A = ...
+    //   B = op(A)
+    //   C_fusion = { A' = ...
+    //                A" = ..
+    //                B' = op(A")
+    //                C' = op(A', B') }
+    //
+    // We prevent this duplication by considering the operands in the order
+    // they appear in the queue. In the example, this ensures that B will be
+    // considered before A.
+    //
+    // We store the original indices of the operands to pass to ShouldFuse.
+    std::vector<int64> sorted_operand_numbers;
+    sorted_operand_numbers.reserve(instruction->operands().size());
+    for (int i = 0; i < instruction->operands().size(); ++i) {
+      // This will happen if we have two possible instructions to fuse the
+      // same operand into; once the operand is fused into one instruction,
+      // the other instruction will get a new get-tuple-element as its
+      // operand, which is not in the queue.
+      // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+      if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+        continue;
+      }
+      sorted_operand_numbers.push_back(i);
+    }
+    std::sort(
+        sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+        [&](int64 i, int64 j) {
+          // Instructions with higher priority in the queue come first.
+          return (
+              FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+              FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+        });
+    return std::make_pair(instruction, sorted_operand_numbers);
+  }
+
+  void OnFusingInstruction(HloInstruction* fusion,
+                           HloInstruction* original_producer,
+                           HloInstruction* original_consumer) override {
+    // Fusing an instruction into a fusion instruction can change the operand
+    // set of the fusion instruction. For simplicity just re-enqueue the
+    // instruction and reconsider it for further fusion in the next iteration.
+    InsertOrDie(&post_order_index_, fusion, post_order_.size());
+    post_order_.push_back(fusion);
+  }
+
+  void RemoveInstruction(HloInstruction* instruction) override {
+    post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+    post_order_index_.erase(instruction);
+  }
+
+ private:
+  std::vector<HloInstruction*> post_order_;
+  tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+}  // namespace
+
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+    HloComputation* computation,
+    const std::function<bool(HloInstruction*)>& skip_producer) {
+  return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
 StatusOr<bool> InstructionFusion::Run(HloModule* module) {
   VLOG(2) << "Before instruction fusion:";
   XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@
     computation_ = computation;
     reachability_ = computation_->ComputeReachability();
 
-    // We want to be able to remove arbitrary instructions from the post order
-    // and also compare positions of instructions in the post order. To make
-    // this possible, create vector of instructions in post order and create a
-    // map from HloInstruction* to the instruction's index in the vector. An
-    // instruction is "removed" from the vector by setting it's element to
-    // nullptr.
-    std::vector<HloInstruction*> post_order =
-        computation_->MakeInstructionPostOrder();
-
-    tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
-    for (size_t i = 0; i < post_order.size(); ++i) {
-      InsertOrDie(&post_order_index, post_order[i], i);
-    }
-
-    HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+    HloInstructionSet do_not_duplicate =
+        ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
+    auto fusion_queue =
+        GetFusionQueue(computation_, [&](HloInstruction* producer) {
+          return do_not_duplicate.count(producer) > 0;
+        });
 
     // Instruction fusion effectively fuses edges in the computation graph
     // (producer instruction -> consumer instruction) so we iterate over all
     // edges. When we fuse an edge, we create a copy of the producer inside the
     // fusion instruction.
-    while (!post_order.empty()) {
-      // We want to iterate in reverse post order, so remove from the back of
-      // the vector.
-      HloInstruction* instruction = post_order.back();
-      post_order.pop_back();
-
-      // Instructions are "removed" from the post order by nulling out the
-      // element in the vector, so if the pointer is null, continue to the next
-      // instruction in the sort.
+    while (true) {
+      auto next_entry =
+          fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+      auto instruction = next_entry.first;
       if (instruction == nullptr) {
-        continue;
+        break;
       }
 
-      // Remove instruction from the index map to ensure the vector and map stay
-      // consistent.
-      post_order_index.erase(instruction);
-
       if (!instruction->IsFusible() &&
           instruction->opcode() != HloOpcode::kFusion) {
         continue;
       }
 
-      // Consider each operand of this instruction for fusion into this
-      // instruction. We want to consider the operands in a particular order to
-      // avoid creating duplicate instruction clones in the fusion instruction.
-      // For example, consider the following expression:
-      //
-      //   A = ...
-      //   B = op(A)
-      //   C = op(A, B)
-      //
-      // If we are considering the operands of C for fusion into C. We might
-      // fuse A or B first. If we fuse A first, we get:
-      //
-      //   A = ...
-      //   B = op(A)
-      //   C_fusion = { A' = ...
-      //                C' = op(A', B) }
-      //
-      // Where A' and C' are clones of A and C, respectively. Now only B is an
-      // operand of the fusion instruction C_fusion, so then we fuse B:
-      //
-      //   A = ...
-      //   B = op(A)
-      //   C_fusion = { A' = ...
-      //                B' = op(A)
-      //                C' = op(A', B') }
-      //
-      // Now A is an operand of C_fusion again, so we then fuse A (again!):
-      //
-      //   A = ...
-      //   B = op(A)
-      //   C_fusion = { A' = ...
-      //                A" = ..
-      //                B' = op(A")
-      //                C' = op(A', B') }
-      //
-      // We prevent this duplication by considering the operands in the reverse
-      // order they appear in the instruction post order. In the example, this
-      // ensures that B will be considered before A.
-      //
-      // We store the original indices of the operands to pass to ShouldFuse.
-      std::vector<int64> sorted_operand_numbers;
-      sorted_operand_numbers.reserve(instruction->operands().size());
-      for (int i = 0; i < instruction->operands().size(); ++i) {
-        // This will happen if we have two possible instructions to fuse the
-        // same operand into; once the operand is fused into one instruction,
-        // the other instruction will get a new get-tuple-element as its
-        // operand, which is not in the post-order index.
-        // TODO(tjoerg): Look into fusing past these multi-output fuse points.
-        if (post_order_index.find(instruction->mutable_operand(i)) ==
-            post_order_index.end()) {
-          continue;
-        }
-        sorted_operand_numbers.push_back(i);
-      }
-      std::sort(
-          sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
-          [&](int64 i, int64 j) {
-            // Instructions with higher indices in the post order come
-            // first.
-            return (
-                FindOrDie(post_order_index, instruction->mutable_operand(i)) >
-                FindOrDie(post_order_index, instruction->mutable_operand(j)));
-          });
+      std::vector<int64>& sorted_operand_numbers = next_entry.second;
 
       for (int64 i : sorted_operand_numbers) {
         HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@
         // TODO(tjoerg): Consider making multi-output fusion the default.
         if (ShouldFuse(instruction, i) &&
             do_not_duplicate.count(operand) == 0) {
+          fusion_queue->PreFusion(operand, instruction);
           fusion_instruction = Fuse(operand, instruction);
         } else if (ShouldFuseIntoMultiOutput(instruction, i) &&
                    !MultiOutputFusionCreatesCycle(operand, instruction)) {
+          fusion_queue->PreFusion(operand, instruction);
           fusion_instruction = FuseIntoMultiOutput(operand, instruction);
         } else {
           continue;
         }
 
-        // Fusing an instruction into a fusion instruction can change the
-        // operand set of the fusion instruction. For simplicity just push the
-        // instruction to the top of the post_order and reconsider it for
-        // further fusion in the next iteration of the outer loop.
-        post_order.push_back(fusion_instruction);
-        InsertOrDie(&post_order_index, fusion_instruction,
-                    post_order.size() - 1);
+        fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+                                          instruction);
         changed = true;
 
         if (operand->user_count() == 0) {
-          // Operand is now dead. Remove from post order by setting its
-          // location to nullptr.
-          post_order[FindOrDie(post_order_index, operand)] = nullptr;
-          post_order_index.erase(operand);
-
+          do_not_duplicate.erase(operand);
+          // Operand is now dead. Remove from queue.
+          fusion_queue->RemoveInstruction(operand);
           // Remove from computation.
           TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
         }
+
+        if (fusion_instruction != instruction) {
+          do_not_duplicate.erase(instruction);
+        }
         break;
       }
     }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 00b6589..c1fde8e 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,6 +24,33 @@
 
 namespace xla {
 
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+  FusionQueue() = default;
+  virtual ~FusionQueue() = default;
+
+  // Dequeues the next fusion candidates: a consumer and the list of producers
+  // as operand indices.
+  virtual std::pair<HloInstruction*, std::vector<int64>>
+  DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+  // A callback passed to the queue implementation right before the producer is
+  // fused into the consumer.
+  virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+  // A callback passed to the queue implementation right after the fusion is
+  // created. Note that original_producer could have been destroyed.
+  virtual void OnFusingInstruction(HloInstruction* fusion,
+                                   HloInstruction* original_producer,
+                                   HloInstruction* original_consumer) {}
+
+  // A callback passed to the queue implementation to notify the removal of an
+  // instruction.
+  virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
 // HLO pass which performs instruction fusion. Instructions are fused
 // "vertically", meaning producing instructions are fused into their consumers
 // with the intent that the loops which compute their values will be fused in
@@ -48,6 +75,13 @@
   static bool IsExpensive(const HloInstruction& instruction);
 
  protected:
+  // Returns a FusionQueue that implements custom order of instructions being
+  // fused. The default implementation processes consumers in reverse post
+  // order.
+  virtual std::unique_ptr<FusionQueue> GetFusionQueue(
+      HloComputation* computation,
+      const std::function<bool(HloInstruction*)>& skip_producer);
+
   // Returns whether the given producer instruction should be fused into the
   // given consumer instruction. producer is necessarily an operand of consumer.
   // Derived classes should define this method to specify which instructions
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 5dea124..a06d611 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -73,30 +73,29 @@
 
   // Transform the ShapedBuffer arguments into literals which the evaluator
   // consumes.
-  std::vector<std::unique_ptr<Literal>> arg_literals;
+  std::vector<Literal> arg_literals;
   for (int64 p = 0; p < computation->num_parameters(); ++p) {
-    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+    TF_ASSIGN_OR_RETURN(Literal arg_literal,
                         transfer_manager->TransferLiteralFromDevice(
                             run_options->stream(), *arguments[p]));
     arg_literals.push_back(std::move(arg_literal));
   }
 
   // Execute the graph using the HloEvaluator.
-  std::unique_ptr<Literal> result_literal;
+  Literal result_literal;
   {
     tensorflow::mutex_lock lock(evaluator_lock_);
-    TF_ASSIGN_OR_RETURN(result_literal,
-                        evaluator_->Evaluate<std::unique_ptr<Literal>>(
-                            *computation, arg_literals));
+    TF_ASSIGN_OR_RETURN(result_literal, evaluator_->Evaluate<Literal>(
+                                            *computation, arg_literals));
   }
 
   // Transform the result literal back into a ShapedBuffer.
   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                       transfer_manager->AllocateScopedShapedBuffer(
-                          result_literal->shape(), run_options->allocator(),
+                          result_literal.shape(), run_options->allocator(),
                           executor->device_ordinal()));
   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
-      run_options->stream(), *result_literal, result));
+      run_options->stream(), result_literal, result));
 
   uint64 end_micros = tensorflow::Env::Default()->NowMicros();
 
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 6e17711..082bf8b 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -855,8 +855,7 @@
             ? instruction.sharding().GetSubSharding(instruction.shape(), index)
             : instruction.sharding();
     // We propagate the sharding to the copied instruction only if it is a
-    // special sharding, like tiled ones, or special devices like the
-    // HostCompute module.
+    // special sharding, like tiled ones.
     // Otherwise it is preferable to leave the new instruction without device,
     // and let the automatic device placer to choose the best location.
     auto device = sharding.UniqueDevice();
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 021fe63..752a614 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -35,7 +35,7 @@
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -49,7 +49,7 @@
 
 using ::testing::ElementsAre;
 
-class LayoutAssignmentTest : public HloTestBase {
+class LayoutAssignmentTest : public HloVerifiedTestBase {
  protected:
   void AssignLayouts(HloModule* module,
                      ComputationLayout* entry_computation_layout,
@@ -91,7 +91,7 @@
     *computation_layout.mutable_parameter_layout(0) = shape_layout;
     *computation_layout.mutable_parameter_layout(1) = shape_layout;
     *computation_layout.mutable_result_layout() = shape_layout;
-    AssignLayouts(module.get(), &computation_layout);
+    AssignLayouts(module, &computation_layout);
     EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
     EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
     EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
@@ -127,7 +127,7 @@
   *computation_layout.mutable_parameter_layout(1) = row_major;
   *computation_layout.mutable_result_layout() = col_major;
 
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
   EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
   EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
   EXPECT_TRUE(LayoutUtil::Equal(
@@ -145,7 +145,7 @@
         {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
     auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
         {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
-    Shape ashape = constant_literal1->shape();
+    Shape ashape = constant_literal1.shape();
 
     auto constant1 = builder.AddInstruction(
         HloInstruction::CreateConstant(std::move(constant_literal1)));
@@ -172,7 +172,7 @@
     ComputationLayout computation_layout(computation->ComputeProgramShape());
     *computation_layout.mutable_result_layout() = shape_layout;
 
-    AssignLayouts(module.get(), &computation_layout);
+    AssignLayouts(module, &computation_layout);
 
     EXPECT_TRUE(LayoutUtil::Equal(
         layout, fusion->fused_parameter(0)->shape().layout()));
@@ -213,7 +213,7 @@
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape());
 
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   EXPECT_TRUE(
       LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
@@ -243,7 +243,7 @@
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
 
   auto select = builder.AddInstruction(HloInstruction::CreateTernary(
-      tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
 
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
@@ -255,7 +255,7 @@
   TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
       result_shape));
 
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
 }
@@ -294,7 +294,7 @@
       result_shape));
 
   LayoutAssignment layout_assignment(&computation_layout);
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   // Layout assignment should have deep copied the result of the computation to
   // address the layout conflict. This results in several Tuple() and
@@ -310,7 +310,7 @@
   EXPECT_TRUE(
       AlgebraicSimplifier(/*is_layout_sensitive=*/true,
                           [](const Shape&, const Shape&) { return false; })
-          .Run(module.get())
+          .Run(module)
           .ValueOrDie());
   HloInstruction* root = module->entry_computation()->root_instruction();
   // Verify layout of the root and the root's operands.
@@ -352,7 +352,7 @@
   *computation_layout.mutable_parameter_layout(0) =
       ShapeLayout(ashape_with_layout);
   *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   auto log_minor_to_major =
       AsInt64Slice(log->shape().layout().minor_to_major());
@@ -393,7 +393,7 @@
   *computation_layout.mutable_parameter_layout(0) =
       ShapeLayout(ashape_with_layout);
   *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   EXPECT_TRUE(
       LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
@@ -432,7 +432,7 @@
       ShapeLayout(input_shape_with_layout);
   *computation_layout.mutable_result_layout() =
       ShapeLayout(output_shape_with_layout);
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
               ElementsAre(0, 1, 2));
@@ -457,13 +457,13 @@
   auto param = builder.AddInstruction(
       HloInstruction::CreateParameter(0, f32_4, "param"));
   auto broadcast = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(f32_34, param, {3}));
+      HloInstruction::CreateBroadcast(f32_34, param, {1}));
   auto transpose = builder.AddInstruction(
       HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
   auto tanh = builder.AddInstruction(
       HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
   auto broadcast2 = builder.AddInstruction(
-      HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
   auto tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({transpose, broadcast2}));
   auto module = CreateNewModule();
@@ -485,7 +485,7 @@
   *computation_layout.mutable_result_layout() =
       ShapeLayout(ShapeUtil::MakeTupleShape(
           {transpose_shape_with_layout, broadcast2_shape_with_layout}));
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
   EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
@@ -551,7 +551,7 @@
   *computation_layout.mutable_parameter_layout(1) =
       ShapeLayout(param1_shape_with_layout);
   OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
-  EXPECT_IS_OK(layout_assignment.Run(module.get()).status());
+  EXPECT_IS_OK(layout_assignment.Run(module).status());
 
   EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
   EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
@@ -575,7 +575,7 @@
   HloComputation* computation =
       module->AddEntryComputation(builder.Build(transpose));
   ComputationLayout computation_layout(computation->ComputeProgramShape());
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
   EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                             transpose->shape(), {2, 3, 0, 1}));
 }
@@ -593,7 +593,7 @@
   HloComputation* computation =
       module->AddEntryComputation(builder.Build(transpose));
   ComputationLayout computation_layout(computation->ComputeProgramShape());
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
   EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                             transpose->shape(), {2, 3, 0, 1}));
 }
@@ -659,18 +659,18 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
+  ParseAndVerifyModule(module_str);
 
-  module =
+  std::unique_ptr<HloModule> compiled_module =
       backend()
           .compiler()
-          ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                          /*device_allocator=*/nullptr)
           .ConsumeValueOrDie();
 
   EXPECT_EQ(Status::OK(), backend()
                               .compiler()
-                              ->RunBackend(std::move(module),
+                              ->RunBackend(std::move(compiled_module),
                                            backend().default_stream_executor(),
                                            /*device_allocator=*/nullptr)
                               .status());
@@ -699,9 +699,9 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
+  ParseAndVerifyModule(module_str);
   ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape());
+      module().entry_computation()->ComputeProgramShape());
   Shape param_shape = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
        ShapeUtil::MakeTupleShape({
@@ -713,19 +713,19 @@
           param_shape));
   computation_layout.mutable_result_layout()->ResetLayout(
       LayoutUtil::MakeLayout({2, 1, 0}));
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(&module(), &computation_layout);
 
-  EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
-  EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
-  EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
-  EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
-  EXPECT_THAT(FindInstruction(module.get(), "gte1")
+  EXPECT_THAT(LayoutOf(&module(), "gte0"), ElementsAre(0, 1, 2));
+  EXPECT_THAT(LayoutOf(&module(), "gte1a"), ElementsAre(1, 2, 0));
+  EXPECT_THAT(LayoutOf(&module(), "gte1b"), ElementsAre(2, 0, 1));
+  EXPECT_THAT(LayoutOf(&module(), "fresult"), ElementsAre(2, 1, 0));
+  EXPECT_THAT(FindInstruction(&module(), "gte1")
                   ->shape()
                   .tuple_shapes(0)
                   .layout()
                   .minor_to_major(),
               ElementsAre(1, 2, 0));
-  EXPECT_THAT(FindInstruction(module.get(), "gte1")
+  EXPECT_THAT(FindInstruction(&module(), "gte1")
                   ->shape()
                   .tuple_shapes(1)
                   .layout()
@@ -785,7 +785,7 @@
   HloComputation* computation = module->AddEntryComputation(builder.Build());
   ComputationLayout computation_layout(computation->ComputeProgramShape());
 
-  AssignLayouts(module.get(), &computation_layout);
+  AssignLayouts(module, &computation_layout);
 
   const HloInstruction* true_root = true_computation->root_instruction();
   const HloInstruction* false_root = false_computation->root_instruction();
@@ -812,7 +812,7 @@
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape());
   LayoutAssignment layout_assignment(&computation_layout);
-  Status error_status = layout_assignment.Run(module.get()).status();
+  Status error_status = layout_assignment.Run(module).status();
   EXPECT_FALSE(error_status.ok());
   EXPECT_THAT(
       error_status.error_message(),
@@ -839,9 +839,9 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
+  ParseAndVerifyModule(module_str);
   ComputationLayout computation_layout(
-      module->entry_computation()->ComputeProgramShape());
+      module().entry_computation()->ComputeProgramShape());
   Shape param_shape = ShapeUtil::MakeTupleShape(
       {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
   TF_ASSERT_OK(
@@ -851,14 +851,13 @@
       LayoutUtil::MakeLayout({1, 0}));
 
   ChannelLayoutConstraints channel_constraints;
-  AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+  AssignLayouts(&module(), &computation_layout, &channel_constraints);
 
-  EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
-  EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
-  EXPECT_TRUE(
-      ShapeUtil::Equal(ShapeUtil::GetSubshape(
-                           FindInstruction(module.get(), "send")->shape(), {0}),
-                       ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+  EXPECT_THAT(LayoutOf(&module(), "gte"), ElementsAre(0, 1));
+  EXPECT_THAT(LayoutOf(&module(), "root"), ElementsAre(1, 0));
+  EXPECT_TRUE(ShapeUtil::Equal(
+      ShapeUtil::GetSubshape(FindInstruction(&module(), "send")->shape(), {0}),
+      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
 }
 
 TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
@@ -873,19 +872,19 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
-  module =
+  ParseAndVerifyModule(module_str);
+  auto compiled_module =
       backend()
           .compiler()
-          ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                          /*device_allocator=*/nullptr)
           .ConsumeValueOrDie();
-
-  auto copy = FindInstruction(module.get(), "copy.1");
-  auto slice = FindInstruction(module.get(), "slice0");
-  EXPECT_EQ(slice->operand(0), copy);
-  EXPECT_TRUE(
-      LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout()));
+  HloInstruction* root =
+      compiled_module->entry_computation()->root_instruction();
+  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+  EXPECT_THAT(root, op::Add(op::Parameter(),
+                            op::Slice(AllOf(op::Copy(op::Parameter(1)),
+                                            op::ShapeWithLayout(shape_copy)))));
 }
 
 TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
@@ -901,19 +900,21 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
-  module =
+  ParseAndVerifyModule(module_str);
+  auto compiled_module =
       backend()
           .compiler()
-          ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                          /*device_allocator=*/nullptr)
           .ConsumeValueOrDie();
-
-  auto copy = FindInstruction(module.get(), "copy.1");
-  auto dslice = FindInstruction(module.get(), "dslice0");
-  EXPECT_EQ(dslice->operand(0), copy);
-  EXPECT_TRUE(
-      LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout()));
+  HloInstruction* root =
+      compiled_module->entry_computation()->root_instruction();
+  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
+  EXPECT_THAT(root,
+              op::Add(op::Parameter(),
+                      op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)),
+                                             op::ShapeWithLayout(shape_copy)),
+                                       op::Parameter(2))));
 }
 
 TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
@@ -930,19 +931,21 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
-  module =
+  ParseAndVerifyModule(module_str);
+  auto compiled_module =
       backend()
           .compiler()
-          ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                          /*device_allocator=*/nullptr)
           .ConsumeValueOrDie();
-
-  auto copy = FindInstruction(module.get(), "copy.1");
-  auto concat = FindInstruction(module.get(), "concat0");
-  EXPECT_EQ(concat->operand(0), copy);
-  EXPECT_TRUE(
-      LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout()));
+  HloInstruction* root =
+      compiled_module->entry_computation()->root_instruction();
+  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
+  EXPECT_THAT(root,
+              op::Add(op::Parameter(),
+                      op::Concatenate(AllOf(op::Copy(op::Parameter(1)),
+                                            op::ShapeWithLayout(shape_copy)),
+                                      op::Parameter(2))));
 }
 
 TEST_F(LayoutAssignmentTest,
@@ -959,16 +962,40 @@
     }
   )";
 
-  auto module = ParseHloString(module_str).ValueOrDie();
-  module =
+  ParseAndVerifyModule(module_str);
+  auto compiled_module =
       backend()
           .compiler()
-          ->RunHloPasses(std::move(module), backend().default_stream_executor(),
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
                          /*device_allocator=*/nullptr)
           .ConsumeValueOrDie();
+  HloInstruction* root =
+      compiled_module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1)));
+}
 
-  auto copy = FindInstruction(module.get(), "copy.1");
-  EXPECT_EQ(copy, nullptr);
+TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
+  const char* module_str = R"(
+    HloModule PropagatingLayoutFromResultToOperand
+
+    ENTRY PropagatingLayoutFromResultToOperand {
+      par0 = f32[4,5]{1,0} parameter(0)
+      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
+    }
+  )";
+
+  ParseAndVerifyModule(module_str);
+  auto compiled_module =
+      backend()
+          .compiler()
+          ->RunHloPasses(module().Clone(), backend().default_stream_executor(),
+                         /*device_allocator=*/nullptr)
+          .ConsumeValueOrDie();
+  HloInstruction* root =
+      compiled_module->entry_computation()->root_instruction();
+  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
+  EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)),
+                                    op::ShapeWithLayout(shape_copy))));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 7d49b8d..a60643b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -75,6 +75,16 @@
   }
 }
 
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+               llvm::IRBuilder<>* b, llvm::Module* module) {
+  std::vector<llvm::Value*> buffer_ptrs;
+  buffer_ptrs.reserve(buffers.size());
+  absl::c_transform(
+      buffers, std::back_inserter(buffer_ptrs),
+      [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); });
+  llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module);
+}
+
 llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
                                  int alignment, llvm::Value* operand,
                                  llvm::IRBuilder<>* b, llvm::Module* module) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index 887fb61..94340b9 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -68,6 +68,11 @@
 void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
                llvm::IRBuilder<>* b, llvm::Module* module);
 
+// Similar to EmitTuple above, except that the output buffers are provided in
+// the form of IrArray.
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+               llvm::IRBuilder<>* b, llvm::Module* module);
+
 // A tuple is an array of pointers, one for each operand. Each pointer points to
 // the output buffer of its corresponding operand. A GetTupleElement instruction
 // forwards the pointer to underlying tuple element buffer at the given index.
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index f0e2566..922ebdf 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -68,9 +68,9 @@
   module->clear_arguments();
   for (const ShapedBuffer* argument : arguments) {
     TF_ASSIGN_OR_RETURN(
-        std::unique_ptr<Literal> literal,
+        Literal literal,
         transfer_manager->TransferLiteralFromDevice(stream, *argument));
-    *module->add_arguments() = literal->ToProto();
+    *module->add_arguments() = literal.ToProto();
   }
   return Status::OK();
 }
@@ -80,9 +80,9 @@
                     TransferManager* transfer_manager, HloSnapshot* module) {
   module->clear_result();
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Literal> literal,
+      Literal literal,
       transfer_manager->TransferLiteralFromDevice(stream, result));
-  *module->mutable_result() = literal->ToProto();
+  *module->mutable_result() = literal.ToProto();
   return Status::OK();
 }
 
@@ -928,16 +928,15 @@
                                        shaped_buffer->device_ordinal()));
 
   TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Literal> result_literal,
+      Literal result_literal,
       execute_backend_->transfer_manager()->TransferLiteralFromDevice(
           stream.get(), *shaped_buffer));
 
-  if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
-                                       result_literal->shape())) {
-    *result->mutable_literal() = result_literal->ToProto();
+  if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
+    *result->mutable_literal() = result_literal.ToProto();
   } else {
     *result->mutable_literal() =
-        result_literal->Relayout(*return_shape)->ToProto();
+        result_literal.Relayout(*return_shape).ToProto();
   }
   return Status::OK();
 }
@@ -959,9 +958,9 @@
 
 Status Service::TransferToServer(const TransferToServerRequest* arg,
                                  TransferToServerResponse* result) {
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+  TF_ASSIGN_OR_RETURN(Literal literal,
                       Literal::CreateFromProto(arg->literal()));
-  const Shape& shape = literal->shape();
+  const Shape& shape = literal.shape();
 
   std::vector<se::StreamExecutor*> replicas;
   if (arg->has_device_handle()) {
@@ -983,7 +982,7 @@
     TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
     TF_RETURN_IF_ERROR(
         execute_backend_->transfer_manager()->TransferLiteralToDevice(
-            stream.get(), *literal, shaped_buffer));
+            stream.get(), literal, shaped_buffer));
     replicated_buffers.emplace_back(std::move(shaped_buffer));
   }
   TF_ASSIGN_OR_RETURN(*result->mutable_data(),
@@ -1018,10 +1017,10 @@
     executor = replicas[arg->replica_id()];
   }
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+  TF_ASSIGN_OR_RETURN(Literal literal,
                       Literal::CreateFromProto(arg->literal()));
-  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
-      executor, *literal);
+  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
+                                                                       literal);
 }
 
 Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
@@ -1049,8 +1048,8 @@
 
   TF_RETURN_IF_ERROR(
       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
-          executor, arg->shape_with_layout(), *literal));
-  *result->mutable_literal() = literal->ToProto();
+          executor, arg->shape_with_layout(), literal));
+  *result->mutable_literal() = literal.ToProto();
   return Status::OK();
 }
 
@@ -1085,18 +1084,17 @@
                       HloModule::CreateFromProto(arg->computation(), config));
 
   HloEvaluator evaluator;
-  TF_ASSIGN_OR_RETURN(auto result_literal,
-                      evaluator.Evaluate<std::unique_ptr<Literal>>(
-                          *module, /*arg_literals=*/{}));
+  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate<Literal>(
+                                               *module, /*arg_literals=*/{}));
 
   // Since the result layout is non-effective to the Evaluator results, explicit
   // relayout here.
   //
   // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
   if (arg->has_output_layout()) {
-    result_literal = result_literal->Relayout(arg->output_layout());
+    result_literal = result_literal.Relayout(arg->output_layout());
   }
-  *result->mutable_literal() = result_literal->ToProto();
+  *result->mutable_literal() = result_literal.ToProto();
 
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2611749..74bdf2a 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1552,8 +1552,8 @@
 }
 
 /* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
-    const Shape& lhs, const Shape& rhs, const Window& window,
-    const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) {
+    const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+    const Window& window, const ConvolutionDimensionNumbers& dnums) {
   TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
   TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
 
@@ -1672,6 +1672,16 @@
         ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
         dnums.DebugString());
   }
+  if (kernel_output_features % feature_group_count > 0) {
+    return InvalidArgument(
+        "Expected output feature dimension (value %d) to be divisible by "
+        "feature_group_count (value %d); "
+        "got <conv>(%s, %s)\n"
+        "Dimension numbers: {%s}.",
+        kernel_output_features, feature_group_count,
+        ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
+        dnums.DebugString());
+  }
   std::vector<int64> window_dims(num_spatial_dims);
   for (int i = 0; i < num_spatial_dims; ++i) {
     window_dims[i] = window.dimensions(i).size();
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index a28345a..96a0ee1 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -108,9 +108,9 @@
   // Infers the shape produced by applying the given convolutional
   // filter (rhs) to lhs in the way specified by the fields on window.
   static StatusOr<Shape> InferConvolveShape(
-      const Shape& lhs, const Shape& rhs, const Window& window,
-      const ConvolutionDimensionNumbers& dimension_numbers,
-      int64 feature_group_count = 1);
+      const Shape& lhs, const Shape& rhs, int64 feature_group_count,
+      const Window& window,
+      const ConvolutionDimensionNumbers& dimension_numbers);
 
   // Infers the shape produced by the given FFT type on the given operand.
   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index cc92e58..864ed43 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -419,8 +419,8 @@
   dim1->set_padding_high(0);
   dim1->set_window_dilation(1);
   dim1->set_base_dilation(1);
-  auto inferred_status =
-      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+  auto inferred_status = ShapeInference::InferConvolveShape(
+      lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@@ -464,8 +464,8 @@
   dim1->set_padding_high(1);
   dim1->set_window_dilation(2);
   dim1->set_base_dilation(1);
-  auto inferred_status =
-      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+  auto inferred_status = ShapeInference::InferConvolveShape(
+      lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@@ -509,8 +509,8 @@
   dim1->set_padding_high(1);
   dim1->set_window_dilation(1);
   dim1->set_base_dilation(2);
-  auto inferred_status =
-      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+  auto inferred_status = ShapeInference::InferConvolveShape(
+      lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
   ASSERT_IS_OK(inferred_status.status());
   Shape inferred_shape = inferred_status.ValueOrDie();
   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@@ -547,8 +547,8 @@
   dim1->set_stride(2);
   dim1->set_padding_low(1);
   dim1->set_padding_high(1);
-  auto inferred_status =
-      ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+  auto inferred_status = ShapeInference::InferConvolveShape(
+      lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums);
   ASSERT_FALSE(inferred_status.ok());
   ASSERT_THAT(inferred_status.status().error_message(),
               HasSubstr("each dimension exactly once"));
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
deleted file mode 100644
index dd53c75..0000000
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/source_map_util.h"
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-namespace source_map_util {
-namespace {
-
-Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
-                                 const char* format, va_list args) {
-  string message;
-  tensorflow::strings::Appendv(&message, format, args);
-  if (!op_metadata.source_file().empty()) {
-    absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
-                          op_metadata.source_line());
-  }
-  return InvalidArgument("%s", message);
-}
-
-}  // namespace
-
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
-                                const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  Status result = InvalidParameterArgumentV(op_metadata, format, args);
-  va_end(args);
-  return result;
-}
-
-Status InvalidParameterArgument(Executable* executable, int parameter_number,
-                                const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  if (executable != nullptr && executable->has_module()) {
-    const HloModule& module = executable->module();
-    const HloComputation& computation = *module.entry_computation();
-    HloInstruction* param = computation.parameter_instruction(parameter_number);
-    const OpMetadata& metadata = param->metadata();
-    Status result = InvalidParameterArgumentV(metadata, format, args);
-    va_end(args);
-    return result;
-  }
-  Status result = InvalidArgumentV(format, args);
-  va_end(args);
-  return result;
-}
-
-}  // namespace source_map_util
-}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index b8d2d54..a21e586 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -42,9 +42,9 @@
   return r;
 }
 
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+StatusOr<Literal> TransferManager::TransferLiteralFromDevice(
     se::Stream* stream, const ShapedBuffer& device_buffer) {
-  StatusOr<std::unique_ptr<Literal>> ret;
+  StatusOr<Literal> ret;
 
   se::Stream* substream = stream->GetOrCreateSubStream();
   substream->ThenWaitFor(stream);
@@ -63,7 +63,7 @@
   if (!s.ok()) {
     return s;
   }
-  return absl::make_unique<Literal>(std::move(literal));
+  return std::move(literal);
 }
 
 Status TransferManager::TransferLiteralFromDevice(
@@ -99,10 +99,10 @@
   return substream->BlockHostUntilDone();
 }
 
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+StatusOr<Literal> TransferManager::TransferArrayFromDevice(
     se::Stream* stream, const Shape& shape,
     const se::DeviceMemoryBase& source) {
-  StatusOr<std::unique_ptr<Literal>> ret;
+  StatusOr<Literal> ret;
   // Implement the synchronous version by waiting on the asynchronous version.
   // Use a substream so that if we are called from a HostCallback we don't
   // deadlock.
@@ -122,7 +122,7 @@
   if (!s.ok()) {
     return s;
   }
-  return absl::make_unique<Literal>(std::move(literal));
+  return std::move(literal);
 }
 
 Status TransferManager::TransferArrayToDevice(
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 2172594..f952e64 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -57,7 +57,7 @@
   // without waiting for any other operation on a stream to complete.
   //
   // This function should be avoided in favor of the asynchronous version below.
-  virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
+  virtual StatusOr<Literal> TransferLiteralFromDevice(
       se::Stream* stream, const ShapedBuffer& device_buffer);
   virtual Status TransferLiteralFromDevice(
       se::Stream* stream, const ShapedBuffer& device_buffer,
@@ -113,9 +113,9 @@
   Status TransferArrayToDeviceAsync(se::Stream* stream,
                                     const LiteralSlice& literal,
                                     const se::DeviceMemoryBase& dest);
-  StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
-      se::Stream* stream, const Shape& shape,
-      const se::DeviceMemoryBase& source);
+  StatusOr<Literal> TransferArrayFromDevice(se::Stream* stream,
+                                            const Shape& shape,
+                                            const se::DeviceMemoryBase& source);
 
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 530f40e..7c1f4b5 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -108,8 +108,7 @@
   }
 
   std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
-      dot->shape(), new_lhs, new_rhs, new_dim_numbers);
-  new_dot->set_precision_config(dot->precision_config());
+      dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config());
   return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
 }
 
@@ -178,8 +177,8 @@
   }
 
   auto new_conv = HloInstruction::CreateConvolve(
-      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
-  new_conv->set_precision_config(convolution.precision_config());
+      convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(),
+      convolution.window(), new_dnums, convolution.precision_config());
   TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
       &convolution, std::move(new_conv)));
 
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 58f767e..79b5c09 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -240,10 +240,12 @@
         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
-      x->shape(), transpose_y->shape(), window, dnums);
+      x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+      dnums);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+      conv_shape.ValueOrDie(), x, transpose_y,
+      /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
@@ -293,10 +295,12 @@
         transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
-      x->shape(), transpose_y->shape(), window, dnums);
+      x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window,
+      dnums);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      conv_shape.ValueOrDie(), x, transpose_y, window, dnums));
+      conv_shape.ValueOrDie(), x, transpose_y,
+      /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
@@ -351,10 +355,12 @@
     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
-      transpose_x->shape(), y->shape(), window, dnums);
+      transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+      dnums);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+      conv_shape.ValueOrDie(), transpose_x, y,
+      /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
@@ -415,10 +421,12 @@
     dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
   }
   StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
-      transpose_x->shape(), y->shape(), window, dnums);
+      transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window,
+      dnums);
   EXPECT_IS_OK(conv_shape);
   HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
-      conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+      conv_shape.ValueOrDie(), transpose_x, y,
+      /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
 
   auto module = CreateNewModule("test_module");
   HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a32d1f9..e9a07b1 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -555,10 +555,10 @@
   // Construct a tuple constant and kCopy it. Verify the points-to set of the
   // copy correctly correctly points into the nested elements of the constant.
   auto builder = HloComputation::Builder(TestName());
-  auto tuple_constant = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
-          {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
-           LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+  Literal elements[] = {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+                        LiteralUtil::CreateR1<float>({2.0, 42})};
+  auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+      LiteralUtil::MakeTuple({&elements[0], &elements[1]})));
   auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
       tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
 
@@ -1064,8 +1064,11 @@
   DotDimensionNumbers dot_dnums;
   dot_dnums.add_lhs_contracting_dimensions(1);
   dot_dnums.add_rhs_contracting_dimensions(0);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      /*new_size=*/2, PrecisionConfig::DEFAULT);
   auto dot = builder.AddInstruction(
-      HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));
 
   auto one = builder.AddInstruction(
       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index 39b6938..516754e 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -25,7 +25,7 @@
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
@@ -34,7 +34,7 @@
 namespace xla {
 namespace {
 
-class TupleSimplifierTest : public HloTestBase {
+class TupleSimplifierTest : public HloVerifiedTestBase {
  protected:
   void Run(HloModule* module, bool change_expected) {
     TupleSimplifier simplifier;
@@ -68,7 +68,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 }
 
 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
@@ -81,7 +81,7 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(builder.Build());
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 }
 
 TEST_F(TupleSimplifierTest, GteOfTuple) {
@@ -103,7 +103,7 @@
 
   EXPECT_THAT(computation->root_instruction(), gte);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), param1);
 }
@@ -131,7 +131,7 @@
   EXPECT_THAT(computation->root_instruction(),
               op::Negate(op::GetTupleElement(op::Tuple())));
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
 }
@@ -162,7 +162,7 @@
 
   EXPECT_THAT(computation->root_instruction(), element);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), param);
 }
@@ -187,7 +187,7 @@
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 
-  Run(module.get(), /*change_expected=*/true);
+  Run(module, /*change_expected=*/true);
 
   EXPECT_THAT(computation->root_instruction(), tuple_param);
 }
@@ -212,7 +212,7 @@
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 
-  Run(module.get(), /*change_expected=*/false);
+  Run(module, /*change_expected=*/false);
 
   EXPECT_THAT(computation->root_instruction(), tuple);
 }
@@ -281,7 +281,7 @@
     entry = module->AddEntryComputation(builder.Build());
   }
 
-  Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+  Run(module, /*change_expected=*/true, /*exclude_entry=*/true);
 
   EXPECT_THAT(c0->root_instruction(), p0);
   EXPECT_THAT(c1->root_instruction(), p1);
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc
index c3c2603..541b117 100644
--- a/tensorflow/compiler/xla/service/while_loop_analysis.cc
+++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc
@@ -183,8 +183,7 @@
   HloEvaluator evaluator(/*max_loop_iterations=*/0);
   auto* while_init = while_op->mutable_operand(0);
   auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx);
-  StatusOr<std::unique_ptr<Literal>> indvar_init_result =
-      evaluator.Evaluate(indvar_init);
+  StatusOr<Literal> indvar_init_result = evaluator.Evaluate(indvar_init);
   if (!indvar_init_result.ok()) {
     VLOG(2) << "Couldn't evaluate induction variable init: "
             << indvar_init_result.status();
@@ -197,31 +196,27 @@
   auto* while_body_indvar = NonConstantOperand(while_body_indvar_update);
 
   // The initial value of the induction variable.
-  std::unique_ptr<Literal> indvar_iter_val =
-      std::move(indvar_init_result).ValueOrDie();
+  Literal indvar_iter_val = std::move(indvar_init_result).ValueOrDie();
   for (int64 trip_count = 0; trip_count != max_value_returned + 1;
        ++trip_count) {
     auto* while_cond = while_op->while_condition();
     auto* while_cond_root = while_cond->root_instruction();
     auto* while_cond_indvar = NonConstantOperand(while_cond_root);
-    StatusOr<std::unique_ptr<Literal>> result =
-        evaluator.EvaluateWithSubstitutions(
-            while_cond_root, {{while_cond_indvar, indvar_iter_val.get()}});
+    StatusOr<Literal> result = evaluator.EvaluateWithSubstitutions(
+        while_cond_root, {{while_cond_indvar, &indvar_iter_val}});
     if (!result.ok()) {
       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
       return nullopt;
     }
-    if (result.ValueOrDie()->data<bool>() == absl::Span<const bool>{false}) {
+    if (result.ValueOrDie().data<bool>() == absl::Span<const bool>{false}) {
       VLOG(2) << "Loop has static trip count of " << trip_count;
       return trip_count;
     }
 
     // Calculate the value of the induction variable after one iteration of the
     // loop, and check whether the while condition is true with this new value.
-    StatusOr<std::unique_ptr<Literal>> indvar_next_result =
-        evaluator.EvaluateWithSubstitutions(
-            while_body_indvar_update,
-            {{while_body_indvar, indvar_iter_val.get()}});
+    StatusOr<Literal> indvar_next_result = evaluator.EvaluateWithSubstitutions(
+        while_body_indvar_update, {{while_body_indvar, &indvar_iter_val}});
     if (!indvar_next_result.ok()) {
       VLOG(2) << "Couldn't evaluate induction variable update: "
               << indvar_next_result.status();
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index aab1180..5614582 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -15,10 +15,10 @@
 
 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
 #include "absl/algorithm/container.h"
+#include "absl/container/inlined_vector.h"
 #include "tensorflow/compiler/xla/service/while_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
 
 namespace xla {
 
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 52c895e..df61010 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -224,14 +224,13 @@
   // REQUIRES: index must exist in the ShapeTree.
   iterator find(ShapeIndexView index) {
     Node* element = Lookup(index);
-    return iterator(&nodes_, typename std::vector<Node>::iterator(element),
-                    /*iterate_leaves_only=*/false);
+    auto element_iter = nodes_.begin() + (element - &nodes_[0]);
+    return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
   }
   const_iterator find(ShapeIndexView index) const {
     Node* element = Lookup(index);
-    return iterator(&nodes_,
-                    typename std::vector<Node>::const_iterator(element),
-                    /*iterate_leaves_only=*/false);
+    auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
+    return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
   }
 
   // Returns the number of leaf nodes in the tree.
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9772c06..96c80fd 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -441,6 +441,19 @@
   return count;
 }
 
+/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
+                                              PrimitiveType primitive_type) {
+  if (shape.element_type() == primitive_type) {
+    return true;
+  }
+  for (const Shape& element_shape : shape.tuple_shapes()) {
+    if (HasPrimitiveType(element_shape, primitive_type)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 /* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
   return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
 }
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 8234fcd..623ae39 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -180,6 +180,10 @@
   // As ElementsIn(), but recurses through tuples.
   static int64 ElementsInRecursive(const Shape& shape);
 
+  // Returns true if shape has the primitive type, recurses through tuples.
+  static bool HasPrimitiveType(const Shape& shape,
+                               PrimitiveType primitive_type);
+
   // Returns true if 'shape' is an array with zero elements.
   static bool IsZeroElementArray(const Shape& shape);
 
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 6ca4085..c622ecd 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -445,6 +445,22 @@
   EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
 }
 
+TEST(ShapeUtilTest, HasPrimitiveType) {
+  EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
+  EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
+  EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
+  EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
+  EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+      ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
+      S32));
+  EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+      ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShape(S32, {}),
+           ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
+      S16));
+}
+
 TEST(ShapeUtilTest, IsZeroElementArray) {
   EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
   EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 36b8fb2..30e3077 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -75,7 +75,6 @@
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:transfer_manager",
         "//tensorflow/core:lib",
-        "//tensorflow/core:stream_executor_headers_lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
@@ -648,6 +647,7 @@
     ],
     shard_count = 48,
     tags = [
+        "broken",
         "manual",
         "notap",
     ],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 0bf4556..c257566 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -41,7 +41,6 @@
 namespace xla {
 namespace {
 
-
 class ArrayElementwiseOpTest : public ClientLibraryTestBase {
  public:
   ErrorSpec error_spec_{0.0001, 0.0001};
@@ -227,10 +226,10 @@
                           0x8000000000000000LL,
                           0x8000000000000000LL,
                           1};
-  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
-  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
   std::unique_ptr<GlobalData> lhs_data =
-      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
 
   std::vector<uint64> rhs{1,
                           0x7FFFFFFFFFFFFFFLL,
@@ -241,10 +240,10 @@
                           0,
                           1,
                           0x8000000000000000LL};
-  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
-  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
   std::unique_ptr<GlobalData> rhs_data =
-      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
 
   Add(lhs_param, rhs_param);
 
@@ -267,10 +266,10 @@
                          1,
                          0,
                          -1};
-  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
-  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+  Literal lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
   std::unique_ptr<GlobalData> lhs_data =
-      client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
+      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();
 
   std::vector<int64> rhs{-1,
                          0,
@@ -280,10 +279,10 @@
                          0x7FFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFFLL,
                          0x7FFFFFFFFFFFFFFFLL};
-  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
-  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+  Literal rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
   std::unique_ptr<GlobalData> rhs_data =
-      client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
+      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();
 
   Sub(lhs_param, rhs_param);
 
@@ -299,16 +298,16 @@
   XlaBuilder b(TestName());
 
   std::vector<uint64> lhs{static_cast<uint64>(0x8000000000000000ULL)};
-  std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
-  auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
+  Literal lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
 
   std::vector<uint64> rhs{static_cast<uint64>(0x7FFFFFFFFFFFFFFFULL)};
-  std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
-  auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
+  Literal rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
 
   Lt(lhs_param, rhs_param);
 
-  ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)});
+  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
 }
 
 TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
@@ -321,16 +320,16 @@
     b_values.push_back(2 * i / static_cast<float>(count + 2));
   }
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+  Literal a_literal = LiteralUtil::CreateR1<float>({a_values});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
   auto a_constant = ConstantR1<float>(&builder, a_values);
-  auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
+  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");
 
-  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+  Literal b_literal = LiteralUtil::CreateR1<float>({b_values});
   std::unique_ptr<GlobalData> b_data =
-      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
-  auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+      client_->TransferToServer(b_literal).ConsumeValueOrDie();
+  auto b_constant = Parameter(&builder, 1, a_literal.shape(), "b_param");
   auto b_param = ConstantR1<float>(&builder, b_values);
 
   auto sum1 = Add(a_constant, b_constant);
@@ -1422,12 +1421,12 @@
   std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
   std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 
-  std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+  Literal param_literal = LiteralUtil::CreateR1<float>(values);
   std::unique_ptr<GlobalData> param_data =
-      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param_literal).ConsumeValueOrDie();
 
   auto sum = ConstantR0<float>(&b, 0.0f);
-  auto param = Parameter(&b, 0, param_literal->shape(), "param");
+  auto param = Parameter(&b, 0, param_literal.shape(), "param");
   for (float exponent : exponents) {
     sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
   }
@@ -1450,14 +1449,14 @@
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
   Pow(Exp(param0), param1);
 
   std::vector<float> expected(values0.size());
@@ -1475,14 +1474,14 @@
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
   Log(Pow(param0, param1));
 
   std::vector<float> expected(values0.size());
@@ -1500,14 +1499,14 @@
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
   Mul(Exp(param0), Exp(param1));
 
   std::vector<float> expected(values0.size());
@@ -1525,14 +1524,14 @@
   std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
   std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
   Div(param0, Exp(param1));
 
   std::vector<float> expected(values0.size());
@@ -1551,20 +1550,20 @@
   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
   std::unique_ptr<GlobalData> data2 =
-      client_->TransferToServer(*literal2).ConsumeValueOrDie();
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
-  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+      client_->TransferToServer(literal2).ConsumeValueOrDie();
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
   Div(Div(param0, param1), param2);
 
   std::vector<float> expected(values0.size());
@@ -1583,21 +1582,21 @@
   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
   std::unique_ptr<GlobalData> data2 =
-      client_->TransferToServer(*literal2).ConsumeValueOrDie();
+      client_->TransferToServer(literal2).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
-  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
   Div(param0, Div(param1, param2));
 
   std::vector<float> expected(values0.size());
@@ -1616,21 +1615,21 @@
   std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
   std::unique_ptr<GlobalData> data2 =
-      client_->TransferToServer(*literal2).ConsumeValueOrDie();
+      client_->TransferToServer(literal2).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
-  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
   Div(param0, Pow(param1, param2));
 
   std::vector<float> expected(values0.size());
@@ -1650,26 +1649,26 @@
   std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
   std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
+  Literal literal0 = LiteralUtil::CreateR1<float>(values0);
   std::unique_ptr<GlobalData> data0 =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
+  Literal literal1 = LiteralUtil::CreateR1<float>(values1);
   std::unique_ptr<GlobalData> data1 =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
+  Literal literal2 = LiteralUtil::CreateR1<float>(values2);
   std::unique_ptr<GlobalData> data2 =
-      client_->TransferToServer(*literal2).ConsumeValueOrDie();
+      client_->TransferToServer(literal2).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
+  Literal literal3 = LiteralUtil::CreateR1<float>(values3);
   std::unique_ptr<GlobalData> data3 =
-      client_->TransferToServer(*literal3).ConsumeValueOrDie();
+      client_->TransferToServer(literal3).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
-  auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
-  auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+  auto param0 = Parameter(&b, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&b, 1, literal1.shape(), "param1");
+  auto param2 = Parameter(&b, 2, literal2.shape(), "param2");
+  auto param3 = Parameter(&b, 3, literal3.shape(), "param2");
   Div(Div(param0, param1), Div(param2, param3));
 
   std::vector<float> expected(values0.size());
@@ -2096,18 +2095,18 @@
 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal =
+  Literal param1_literal =
       LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Add(p0, p1);
 
   ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
@@ -2118,18 +2117,18 @@
 XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal =
+  Literal param1_literal =
       LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Add(p0, p1);
 
   Array3D<float> expected(0, 7, 0);
@@ -2140,13 +2139,13 @@
 XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
-  auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto p = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Add(a, p);
 
   ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
@@ -2206,9 +2205,9 @@
        0.08,  -1.24, -0.92, 0.49,  1.17,  -0.45, -1.31, -1.44, -0.13, -1.31,
        -0.79, 1.41,  1.21,  1.05});
   TF_ASSERT_OK_AND_ASSIGN(auto input_data,
-                          client_->TransferToServer(*input_literal));
+                          client_->TransferToServer(input_literal));
 
-  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
   Tanh(input);
 
   ComputeAndCompareR1<float>(
@@ -2239,7 +2238,7 @@
 
   // Just to help make sense of the scales here -- exp(89) saturates float32 and
   // exp(-10) is smaller than our error spec.
-  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+  Literal input_literal = LiteralUtil::CreateR1<float>(
       {1.02,   -0.32,  0.85,   0.9,    1.23,   -0.91,  -0.49, 0.8,    -1.31,
        -1.44,  -0.13,  -1.31,  -0.79,  1.41,   1.21,   1.05,  -195.6, -194.5,
        -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5,  -17.4,
@@ -2252,16 +2251,16 @@
        78.3,   79.4,   80.5,   81.6,   82.7,   83.8,   84.9,  85.2,   86.3,
        86.4,   86.5,   87.6,   87.7,   87.8,   87.9});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
-                          client_->TransferToServer(*input_literal));
+                          client_->TransferToServer(input_literal));
 
-  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
   Exp(input);
 
   std::vector<float> expected_result;
-  int64 input_size = input_literal->shape().dimensions(0);
+  int64 input_size = input_literal.shape().dimensions(0);
   expected_result.reserve(input_size);
   for (int64 i = 0; i < input_size; i++) {
-    expected_result.push_back(std::exp(input_literal->Get<float>({i})));
+    expected_result.push_back(std::exp(input_literal.Get<float>({i})));
   }
 
   ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2273,7 +2272,7 @@
   // implementation on XLA CPU.
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
+  Literal input_literal = LiteralUtil::CreateR1<float>(
       {-1.29,    -1.41,    -1.25,    -13.5,    -11.7,    -17.9,    -198,
        -167,     1.29,     1.41,     1.25,     13.5,     11.7,     17.9,
        198,      167,      1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04,  1.84e+04,
@@ -2290,16 +2289,16 @@
        1.7e+31,  1.44e+31, 1.1e+31,  1.4e+32,  1.67e+32, 1.96e+33, 1.11e+33,
        1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
-                          client_->TransferToServer(*input_literal));
+                          client_->TransferToServer(input_literal));
 
-  auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+  auto input = Parameter(&builder, 0, input_literal.shape(), "input");
   Log(input);
 
   std::vector<float> expected_result;
-  int64 input_size = input_literal->shape().dimensions(0);
+  int64 input_size = input_literal.shape().dimensions(0);
   expected_result.reserve(input_size);
   for (int64 i = 0; i < input_size; i++) {
-    expected_result.push_back(std::log(input_literal->Get<float>({i})));
+    expected_result.push_back(std::log(input_literal.Get<float>({i})));
   }
 
   ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
@@ -2465,10 +2464,10 @@
   auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
   Tuple(&builder, {cmp_dim_0, cmp_dim_1});
 
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
-       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
+       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
@@ -2821,10 +2820,9 @@
   std::iota(r1.begin(), r1.end(), 1.0);
 
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> a_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
-  auto a = ConstantLiteral(&builder, *a_literal);
+  Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+  auto a = ConstantLiteral(&builder, a_literal);
   auto b = ConstantR1<float>(&builder, r1);
   Add(a, b, {1});
 
@@ -2886,11 +2884,11 @@
   XlaBuilder builder(TestName());
   auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
   auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
-  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
-  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 
-  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
-  auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
+  auto y = Parameter(&builder, 1, y_literal.shape(), "y");
   auto slice = Slice(x, {1}, {2}, {1});
   Sub(slice, y);
 
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index ac90a3a..bc2ba15 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -63,7 +63,7 @@
         {5.0f, 4.4f},   // p2
     });
     input_array_.FillWithPZ(pz);
-    input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
+    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
     CHECK_EQ(kSamples, input_array_.planes());
     CHECK_EQ(kZ, input_array_.depth());
     CHECK_EQ(kY, input_array_.height());
@@ -242,14 +242,13 @@
   BatchNormTraining(operand, scale, offset,
                     /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
-                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({4, 5}).get(),
-       LiteralUtil::CreateR1<float>({5, 5}).get()});
+                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
+       LiteralUtil::CreateR1<float>({4, 5}),
+       LiteralUtil::CreateR1<float>({5, 5})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
@@ -267,14 +266,13 @@
   BatchNormTraining(operand, scale, offset,
                     /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
-                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({4, 5}).get(),
-       LiteralUtil::CreateR1<float>({5, 5}).get()});
+                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
+       LiteralUtil::CreateR1<float>({4, 5}),
+       LiteralUtil::CreateR1<float>({5, 5})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
@@ -298,13 +296,12 @@
   BatchNormTraining(h0, h1, h2,
                     /*epsilon=*/1, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
-           .get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -331,14 +328,13 @@
   BatchNormTraining(h0, h1, h2,
                     /*epsilon=*/-100, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR3FromArray3D<float>(
-           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
-           .get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
-       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
+       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
+       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {operand.get(), scale.get(), offset.get()},
                          ErrorSpec(0.1));
 }
@@ -363,14 +359,13 @@
   BatchNormGrad(operand, scale, mean, var, grad_output,
                 /*epsilon=*/0.0, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
-                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
-           .get(),
-       LiteralUtil::CreateR1<float>({0, 0}).get(),
-       LiteralUtil::CreateR1<float>({16, 20}).get()});
+                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
+       LiteralUtil::CreateR1<float>({0, 0}),
+       LiteralUtil::CreateR1<float>({16, 20})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
 struct BatchNormTestParam {
@@ -522,22 +517,22 @@
   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 
   auto input_activations =
-      Parameter(&builder, 0, input_literal->shape(), "input");
+      Parameter(&builder, 0, input_literal.shape(), "input");
   auto scale_activations =
-      Parameter(&builder, 1, scale_literal->shape(), "offset");
+      Parameter(&builder, 1, scale_literal.shape(), "offset");
   auto offset_activations =
-      Parameter(&builder, 2, offset_literal->shape(), "scale");
+      Parameter(&builder, 2, offset_literal.shape(), "scale");
 
-  auto expected = LiteralUtil::MakeTuple(
-      {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
-       LiteralUtil::CreateR1<float>(var).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
+       LiteralUtil::CreateR1<float>(var)});
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> offset_data =
-      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
 
   BatchNormTraining(input_activations, scale_activations, offset_activations,
                     epsilon, feature_index);
@@ -547,7 +542,7 @@
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
   ComputeAndCompareTuple(
-      &builder, *expected,
+      &builder, expected,
       {input_data.get(), scale_data.get(), offset_data.get()},
       ErrorSpec(0.01, 1));
 }
@@ -622,27 +617,27 @@
   auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
 
   auto input_activations =
-      Parameter(&builder, 0, input_literal->shape(), "input");
+      Parameter(&builder, 0, input_literal.shape(), "input");
   auto scale_activations =
-      Parameter(&builder, 1, scale_literal->shape(), "offset");
+      Parameter(&builder, 1, scale_literal.shape(), "offset");
   auto offset_activations =
-      Parameter(&builder, 2, offset_literal->shape(), "scale");
-  auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
+      Parameter(&builder, 2, offset_literal.shape(), "scale");
+  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
   auto variance_activations =
-      Parameter(&builder, 4, var_literal->shape(), "variance");
+      Parameter(&builder, 4, var_literal.shape(), "variance");
 
   Array4D<float> expected = normalized;
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> offset_data =
-      client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
+      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> mean_data =
-      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> variance_data =
-      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+      client_->TransferToServer(var_literal).ConsumeValueOrDie();
 
   BatchNormInference(input_activations, scale_activations, offset_activations,
                      mean_activations, variance_activations, epsilon,
@@ -811,40 +806,37 @@
   auto grad_output_literal =
       LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
 
-  auto input_parameter =
-      Parameter(&builder, 0, input_literal->shape(), "input");
-  auto scale_parameter =
-      Parameter(&builder, 1, scale_literal->shape(), "scale");
-  auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
-  auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
+  auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
+  auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
+  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
+  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
   auto grad_output_parameter =
-      Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
+      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");
 
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> scale_data =
-      client_->TransferToServer(*scale_literal).ConsumeValueOrDie();
+      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> mean_data =
-      client_->TransferToServer(*mean_literal).ConsumeValueOrDie();
+      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> var_data =
-      client_->TransferToServer(*var_literal).ConsumeValueOrDie();
+      client_->TransferToServer(var_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> grad_output_data =
-      client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
+      client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();
 
   BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
                 grad_output_parameter, epsilon, feature_index);
 
-  auto expected =
-      LiteralUtil::MakeTuple({expected_grad_activation.get(),
-                              LiteralUtil::CreateR1<float>(grad_scale).get(),
-                              LiteralUtil::CreateR1<float>(grad_offset).get()});
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
+       LiteralUtil::CreateR1<float>(grad_offset)});
 
   // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
   // disables constant folding, but we want it enabled for our zero-sized tensor
   // testcase.
   execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
 
-  ComputeAndCompareTuple(&builder, *expected,
+  ComputeAndCompareTuple(&builder, expected,
                          {input_data.get(), scale_data.get(), mean_data.get(),
                           var_data.get(), grad_output_data.get()},
                          ErrorSpec(0.01, 1));
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index 65589b0..e9728e6 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -95,22 +95,19 @@
 
   BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<bfloat16>(
            {{{{static_cast<bfloat16>(-1.6875f)},
               {static_cast<bfloat16>(-2.04f)}},
              {{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
             {{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
-             {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
-           .get(),
+             {{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}}),
        LiteralUtil::CreateR1<bfloat16>(
-           {static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
-           .get(),
+           {static_cast<bfloat16>(4), static_cast<bfloat16>(5)}),
        LiteralUtil::CreateR1<bfloat16>(
-           {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
-           .get()});
+           {static_cast<bfloat16>(5), static_cast<bfloat16>(5)})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01, 0.02));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01, 0.02));
 }
 
 XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
@@ -139,21 +136,18 @@
   BatchNormGrad(operand, scale, mean, var, grad_output,
                 /*epsilon=*/0.0, kFeatureIndex);
 
-  auto expected = LiteralUtil::MakeTuple(
+  auto expected = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR4<bfloat16>(
            {{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
              {{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
             {{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
-             {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
-           .get(),
+             {{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}}),
        LiteralUtil::CreateR1<bfloat16>(
-           {static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
-           .get(),
+           {static_cast<bfloat16>(0), static_cast<bfloat16>(0)}),
        LiteralUtil::CreateR1<bfloat16>(
-           {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
-           .get()});
+           {static_cast<bfloat16>(16), static_cast<bfloat16>(20)})});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01));
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index fe4267c..dde19fb 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -60,10 +60,10 @@
                                          float end, int seed) {
     *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
     r3_array->FillRandom(start, end, seed);
-    auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+    auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
         LayoutUtil::MakeLayout(minor_to_major));
     std::unique_ptr<GlobalData> r3_global_data =
-        client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+        client_->TransferToServer(r3_data).ConsumeValueOrDie();
     return r3_global_data;
   }
 
@@ -74,10 +74,10 @@
                                          float end, int seed) {
     *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
     r2_array->FillRandom(start, end, seed);
-    auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+    auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
         LayoutUtil::MakeLayout(minor_to_major));
     std::unique_ptr<GlobalData> r2_global_data =
-        client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+        client_->TransferToServer(r2_data).ConsumeValueOrDie();
     return r2_global_data;
   }
 
@@ -293,7 +293,7 @@
   XlaBuilder b(TestName());
 
   Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
-      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
                               {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
       /*broadcast_dimensions=*/{1, 2});
 
@@ -301,7 +301,7 @@
       LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
                                     {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 struct R3ImplicitBroadcastSpec {
@@ -370,8 +370,7 @@
   }
   auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
   ComputeAndCompareLiteral(
-      &builder, *expected,
-      {r3_implicit_global_data.get(), r3_global_data.get()},
+      &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
       ErrorSpec(1e-7, 1e-7));
 }
 
@@ -395,89 +394,89 @@
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+  ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
                            ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
   XlaBuilder b(TestName());
-  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
   XlaBuilder b(TestName());
-  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
   XlaBuilder b(TestName());
   auto r1 =
-      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+      ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
   XlaBuilder b(TestName());
   auto r1 =
-      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+      ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
   XlaBuilder b(TestName());
   auto r1 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
   XlaBuilder b(TestName());
-  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1);
 
   auto expected =
       LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 struct R2ImplicitBroadcastSpec {
@@ -618,7 +617,7 @@
 
   auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
   ComputeAndCompareLiteral(
-      &builder, *expected,
+      &builder, expected,
       {r2_implicit_global_data1.get(), r2_global_data.get(),
        r2_implicit_global_data2.get()},
       ErrorSpec(1e-6, 1e-6));
@@ -630,65 +629,63 @@
 
 XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
   XlaBuilder b(TestName());
-  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
-  auto r2 =
-      ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+  auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
   Add(r2, r1);
 
   auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
   XlaBuilder b(TestName());
-  auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
-  auto r2 =
-      ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+  auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+  auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
   Add(r2, r1);
 
   auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
   XlaBuilder b(TestName());
   auto r1 = ConstantR1<float>(&b, {10, 20});
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r3, r1, {0});
 
   auto expected = LiteralUtil::CreateR3<float>(
       {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
   XlaBuilder b(TestName());
   auto r1 = ConstantR1<float>(&b, {10, 20});
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r1, r3, {1});
 
   auto expected = LiteralUtil::CreateR3<float>(
       {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
   XlaBuilder b(TestName());
   auto r1 = ConstantR1<float>(&b, {10, 20});
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   Add(r1, r3, {2});
 
   auto expected = LiteralUtil::CreateR3<float>(
       {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -697,7 +694,7 @@
   auto r1_1 = ConstantR1<float>(&b, {100, 200});
   auto r1_2 = ConstantR1<float>(&b, {10, 20});
   auto r3 = ConstantLiteral(
-      &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+      &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
   for (int i = 0; i < 3; ++i) {
     r3 = Add(r1_0, r3, {0});
     r3 = Add(r3, r1_1, {1});
@@ -709,7 +706,7 @@
       {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
        {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -730,7 +727,7 @@
       {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
        {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
 
-  ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+  ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
 }
 
 XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -739,7 +736,7 @@
   XlaBuilder b(TestName());
 
   Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
-      ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
                               {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
       /*broadcast_dimensions=*/{1, 2});
 
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 74d4d2e..9966e46 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -46,8 +46,8 @@
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
-                                    *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR0<float>(42.0), result,
+                                    error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -63,7 +63,7 @@
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+      LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), result,
       error_spec_));
 }
 
@@ -86,12 +86,12 @@
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
-      LiteralSlice(*result, {0}), error_spec_));
+      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+      LiteralSlice(result, {0}), error_spec_));
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
-      LiteralSlice(*result, {1}), error_spec_));
+      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+      LiteralSlice(result, {1}), error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -107,7 +107,7 @@
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), result,
       error_spec_));
 }
 
@@ -126,7 +126,7 @@
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+      LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), result,
       error_spec_));
 }
 
@@ -143,9 +143,9 @@
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
-                                     {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
-      *result, error_spec_));
+      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+      result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,9 +166,8 @@
   Array2D<float> pz({{1, 2}, {1, 2}});
   expected.FillWithPZ(pz);
 
-  EXPECT_TRUE(
-      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
-                            *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -197,9 +196,8 @@
   }
   expected.FillWithYX(yx);
 
-  EXPECT_TRUE(
-      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
-                            *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -220,8 +218,8 @@
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
-                                    *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
+                                    result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -240,9 +238,8 @@
   Array4D<float> expected(64, 64, 3, 3);
   expected.Fill(1.0f);
 
-  EXPECT_TRUE(
-      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
-                            *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -263,9 +260,8 @@
   Array4D<float> expected(3, 3, 2, 2);
   expected.FillWithYX(to_broadcast);
 
-  EXPECT_TRUE(
-      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
-                            *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -295,9 +291,8 @@
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  EXPECT_TRUE(
-      LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
-                            *result, error_spec_));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index b1d1821..8b31e53 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -77,8 +77,7 @@
 XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
   XlaBuilder builder(TestName());
   XlaComputation callee = CreateR0F32IdentityComputation();
-  auto constant =
-      ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+  auto constant = ConstantLiteral(&builder, LiteralUtil::CreateR0<float>(42.0));
   Call(&builder, callee, {constant});
 
   ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
@@ -87,8 +86,8 @@
 XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
   XlaBuilder builder(TestName());
   XlaComputation callee = CreateR1S0F32AdditionComputation();
-  auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
-  auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+  auto x = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
+  auto y = ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({}));
   Call(&builder, callee, {x, y});
 
   ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
@@ -98,9 +97,9 @@
   XlaBuilder builder(TestName());
   XlaComputation callee = CreateR1S2F32AdditionComputation();
   auto x =
-      ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+      ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
   auto y =
-      ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+      ConstantLiteral(&builder, LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
   Call(&builder, callee, {x, y});
 
   ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
@@ -133,7 +132,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> start,
-      client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
+      client_->TransferToServer(LiteralUtil::CreateR0<float>(1.0f)));
   ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
 }
 
@@ -141,10 +140,10 @@
   XlaBuilder builder(TestName());
   XlaComputation callee = CreateR0F32TupleComputation();
   auto elem = LiteralUtil::CreateR0<float>(42.0);
-  auto tuple = LiteralUtil::MakeTuple({elem.get()});
-  Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
+  auto tuple = LiteralUtil::MakeTuple({&elem});
+  Call(&builder, callee, {ConstantLiteral(&builder, elem)});
 
-  ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+  ComputeAndCompareTuple(&builder, tuple, {}, ErrorSpec(0.01f));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index a4eb57f..2f1510f 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -38,14 +38,14 @@
   XlaBuilder builder("add_two_params");
   auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
 
-  auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
-  auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+  auto p0 = Parameter(&builder, 0, param_literal.shape(), "param0");
+  auto p1 = Parameter(&builder, 1, param_literal.shape(), "param1");
   Add(p0, p1);
 
   auto param0_data =
-      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param_literal).ConsumeValueOrDie();
   auto param1_data =
-      client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param_literal).ConsumeValueOrDie();
 
   auto computation_status = builder.Build();
   ASSERT_IS_OK(computation_status.status());
@@ -86,12 +86,12 @@
   auto computation = computation_status.ConsumeValueOrDie();
 
   auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
-  auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+  auto f32_data = client_->TransferToServer(f32_literal).ConsumeValueOrDie();
   auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
   auto f32_4_data =
-      client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+      client_->TransferToServer(f32_4_literal).ConsumeValueOrDie();
   auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
-  auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+  auto u8_4_data = client_->TransferToServer(u8_4_literal).ConsumeValueOrDie();
 
   // Match
   auto status = client_->Execute(
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 8a236db..fbdf0fc 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -101,7 +101,7 @@
   return client_->Execute(computation, arguments, &execution_options_);
 }
 
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
     const Shape* shape_with_output_layout) {
   ExecutionOptions execution_options = execution_options_;
@@ -113,7 +113,7 @@
                                      &execution_options);
 }
 
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
     XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
     const Shape* shape_with_output_layout) {
   // Build the computation, as a convenience.
@@ -121,8 +121,7 @@
   return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
 }
 
-StatusOr<std::unique_ptr<Literal>>
-ClientLibraryTestBase::ExecuteAndTransferReference(
+StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
     const Shape* shape_with_output_layout) {
   ExecutionOptions execution_options = execution_options_;
@@ -148,15 +147,15 @@
   if (!result.ok()) {
     return result.status().ToString();
   } else {
-    return result.ValueOrDie()->ToString();
+    return result.ValueOrDie().ToString();
   }
 }
 
 void ClientLibraryTestBase::ComputeAndCompareR1(
     XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  Literal expected_literal = LiteralUtil::CreateR1(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -182,7 +181,7 @@
                              const string& error_message)>& verify_output) {
   // Try with no layout requirement.
   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments));
-  verify_output(*actual, "");
+  verify_output(actual, "");
 
   // Try with all output layouts.
   std::vector<int64> minor_to_major(ShapeUtil::Rank(expected.shape()));
@@ -193,7 +192,7 @@
         AsInt64Slice(expected.shape().dimensions()), minor_to_major);
     TF_ASSIGN_OR_RETURN(auto actual,
                         ExecuteAndTransfer(computation, arguments, &layout));
-    verify_output(*actual,
+    verify_output(actual,
                   absl::StrCat("Test with output layout: ",
                                ShapeUtil::HumanStringWithLayout(layout)));
   } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end()));
@@ -218,9 +217,9 @@
       TF_ASSIGN_OR_RETURN(auto literal,
                           client_->Transfer(*arguments[index], nullptr));
       // Skip tuples because they don't have a rank.
-      if (ShapeUtil::IsTuple(literal->shape())) {
+      if (ShapeUtil::IsTuple(literal.shape())) {
         layout_strings.push_back(
-            ShapeUtil::HumanStringWithLayout(literal->shape()));
+            ShapeUtil::HumanStringWithLayout(literal.shape()));
         arguments_with_layout.push_back(arguments[index]);
         TF_RETURN_IF_ERROR(choose(index + 1));
         arguments_with_layout.pop_back();
@@ -228,15 +227,15 @@
         return Status::OK();
       }
 
-      std::vector<int64> minor_to_major(ShapeUtil::Rank(literal->shape()));
+      std::vector<int64> minor_to_major(ShapeUtil::Rank(literal.shape()));
       std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
       do {
         auto literal_relayout =
-            literal->Relayout(LayoutUtil::MakeLayout(minor_to_major));
+            literal.Relayout(LayoutUtil::MakeLayout(minor_to_major));
         layout_strings.push_back(
-            ShapeUtil::HumanStringWithLayout(literal_relayout->shape()));
+            ShapeUtil::HumanStringWithLayout(literal_relayout.shape()));
         TF_ASSIGN_OR_RETURN(auto data,
-                            client_->TransferToServer(*literal_relayout));
+                            client_->TransferToServer(literal_relayout));
         arguments_with_layout.push_back(data.get());
         TF_RETURN_IF_ERROR(choose(index + 1));
         arguments_with_layout.pop_back();
@@ -256,7 +255,7 @@
     for (const auto& str : layout_strings) {
       absl::StrAppend(&error_message, str, " ");
     }
-    verify_output(*actual, error_message);
+    verify_output(actual, error_message);
     return Status::OK();
   };
 
@@ -290,11 +289,11 @@
   // We allow using a float expected literal for a bfloat16 output. In this
   // case, we need to convert the expected literal to bfloat16.
   const Literal* expected_ptr = &expected;
-  std::unique_ptr<Literal> converted_expected;
+  Literal converted_expected;
   Shape layout_shape;
   if (use_bfloat16_) {
     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
-    expected_ptr = converted_expected.get();
+    expected_ptr = &converted_expected;
     if (shape_with_layout != nullptr) {
       layout_shape = *shape_with_layout;
       ShapeUtil::ForEachMutableSubshape(
@@ -319,7 +318,7 @@
   }
   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                       shape_with_layout));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual));
   return Status::OK();
 }
 
@@ -346,11 +345,11 @@
   // We allow using a float expected literal for a bfloat16 output. In this
   // case, we need to convert the expected literal to bfloat16.
   const Literal* expected_ptr = &expected;
-  std::unique_ptr<Literal> converted_expected;
+  Literal converted_expected;
   Shape layout_shape;
   if (use_bfloat16_) {
     converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
-    expected_ptr = converted_expected.get();
+    expected_ptr = &converted_expected;
     if (shape_with_layout != nullptr) {
       layout_shape = *shape_with_layout;
       ShapeUtil::ForEachMutableSubshape(
@@ -376,7 +375,7 @@
   }
   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                       shape_with_layout));
-  EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error));
   return Status::OK();
 }
 
@@ -391,12 +390,12 @@
   auto actual = actual_status.ConsumeValueOrDie();
 
   // Turn the expected value into a literal.
-  std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+  Literal expected_literal = LiteralUtil::CreateR1U8(expected);
 
-  VLOG(1) << "expected: " << expected_literal->ToString();
-  VLOG(1) << "actual:   " << actual->ToString();
+  VLOG(1) << "expected: " << expected_literal.ToString();
+  VLOG(1) << "actual:   " << actual.ToString();
 
-  EXPECT_EQ(expected, actual->GetR1U8AsString());
+  EXPECT_EQ(expected, actual.GetR1U8AsString());
 }
 
 void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -408,7 +407,7 @@
     return;
   }
   auto actual = actual_status.ConsumeValueOrDie();
-  EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
 }
 
 void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -420,7 +419,7 @@
     return;
   }
   auto actual = actual_status.ConsumeValueOrDie();
-  EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error));
 }
 
 void ClientLibraryTestBase::ComputeAndCompare(
@@ -430,9 +429,9 @@
   if (!status_or_data.ok()) {
     return;
   }
-  std::unique_ptr<Literal> reference, result;
+  Literal reference, result;
   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(reference, result));
 }
 
 void ClientLibraryTestBase::ComputeAndCompare(
@@ -442,12 +441,12 @@
   if (!status_or_data.ok()) {
     return;
   }
-  std::unique_ptr<Literal> reference, result;
+  Literal reference, result;
   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
+  EXPECT_TRUE(LiteralTestUtil::Near(reference, result, error));
 }
 
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+StatusOr<std::pair<Literal, Literal>>
 ClientLibraryTestBase::ComputeValueAndReference(
     XlaBuilder* builder, absl::Span<const Literal> arguments) {
   // Transfer the arguments to the executor service. We put the unique_ptr's
@@ -569,8 +568,8 @@
 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
                                                        XlaBuilder* builder) {
   return ConstantLiteral(builder, use_bfloat16_
-                                      ? *LiteralUtil::ConvertF32ToBF16(literal)
-                                      : literal);
+                                      ? LiteralUtil::ConvertF32ToBF16(literal)
+                                      : LiteralSlice(literal));
 }
 
 std::unique_ptr<GlobalData>
@@ -600,7 +599,7 @@
 Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16(
     const Literal& literal) {
   if (use_bfloat16_) {
-    return std::move(*LiteralUtil::ConvertF32ToBF16(literal));
+    return LiteralUtil::ConvertF32ToBF16(literal);
   }
   return literal.Clone();
 }
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 22dfdfb..9d32f4f 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -95,11 +95,11 @@
   StatusOr<std::unique_ptr<GlobalData>> Execute(
       XlaBuilder* builder, absl::Span<GlobalData* const> arguments);
 
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+  StatusOr<Literal> ExecuteAndTransfer(
       XlaBuilder* builder, absl::Span<GlobalData* const> arguments,
       const Shape* shape_with_output_layout = nullptr);
 
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+  StatusOr<Literal> ExecuteAndTransfer(
       const XlaComputation& computation,
       absl::Span<GlobalData* const> arguments,
       const Shape* shape_with_output_layout = nullptr);
@@ -107,7 +107,7 @@
   // This executes the computation via the reference client (which connects a
   // interpreter backend). The result is used as the expected values of the
   // computation.
-  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransferReference(
+  StatusOr<Literal> ExecuteAndTransferReference(
       const XlaComputation& computation,
       absl::Span<GlobalData* const> arguments,
       const Shape* shape_with_output_layout = nullptr);
@@ -282,7 +282,7 @@
 
   template <class T>
   XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
-    return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
+    return AddParam(LiteralUtil::CreateFromArray(argument), builder);
   }
 
   // Creates a constant instruction with the given literal. When the
@@ -297,14 +297,14 @@
   template <typename NativeT>
   XlaOp CreateConstantFromArray(const Array<NativeT>& array,
                                 XlaBuilder* builder) {
-    return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+    return CreateConstantFromLiteral(LiteralUtil::CreateFromArray(array),
                                      builder);
   }
 
   // Same as CreateConstantFromArray, but for scalars.
   template <typename NativeT>
   XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
-    return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
+    return CreateConstantFromLiteral(LiteralUtil::CreateR0<NativeT>(value),
                                      builder);
   }
 
@@ -375,9 +375,8 @@
   // Executes the computation and calculates the expected reference value using
   // the reference client. Returns two literals in the order of (expected,
   // actual).
-  StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
-  ComputeValueAndReference(XlaBuilder* builder,
-                           absl::Span<const Literal> arguments);
+  StatusOr<std::pair<Literal, Literal>> ComputeValueAndReference(
+      XlaBuilder* builder, absl::Span<const Literal> arguments);
 
   Client* client_;
   Client* ref_client_;  // To compute reference result.
@@ -412,9 +411,8 @@
 void ClientLibraryTestBase::ComputeAndCompareR0(
     XlaBuilder* builder, NativeT expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal =
-      LiteralUtil::CreateR0<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -428,9 +426,8 @@
                     std::is_same<NativeT, half>::value ||
                     std::is_same<NativeT, complex64>::value,
                 "Float or complex type required when specifying an ErrorSpec");
-  std::unique_ptr<Literal> expected_literal =
-      LiteralUtil::CreateR0<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  Literal expected_literal = LiteralUtil::CreateR0<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments, error);
 }
 
@@ -438,9 +435,8 @@
 void ClientLibraryTestBase::ComputeAndCompareR1(
     XlaBuilder* builder, absl::Span<const NativeT> expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal =
-      LiteralUtil::CreateR1<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -454,9 +450,8 @@
                     std::is_same<NativeT, half>::value ||
                     std::is_same<NativeT, complex64>::value,
                 "Float or complex type required when specifying an ErrorSpec");
-  std::unique_ptr<Literal> expected_literal =
-      LiteralUtil::CreateR1<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  Literal expected_literal = LiteralUtil::CreateR1<NativeT>(expected);
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments, error);
 }
 
@@ -464,9 +459,9 @@
 void ClientLibraryTestBase::ComputeAndCompareR2(
     XlaBuilder* builder, const Array2D<NativeT>& expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -480,9 +475,9 @@
                     std::is_same<NativeT, half>::value ||
                     std::is_same<NativeT, complex64>::value,
                 "Float or complex type required when specifying an ErrorSpec");
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments, error);
 }
 
@@ -490,9 +485,9 @@
 void ClientLibraryTestBase::ComputeAndCompareR3(
     XlaBuilder* builder, const Array3D<NativeT>& expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -506,9 +501,9 @@
                     std::is_same<NativeT, half>::value ||
                     std::is_same<NativeT, complex64>::value,
                 "Float or complex type required when specifying an ErrorSpec");
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments, error);
 }
 
@@ -516,9 +511,9 @@
 void ClientLibraryTestBase::ComputeAndCompareR4(
     XlaBuilder* builder, const Array4D<NativeT>& expected,
     absl::Span<GlobalData* const> arguments) {
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments);
 }
 
@@ -532,9 +527,9 @@
                     std::is_same<NativeT, half>::value ||
                     std::is_same<NativeT, complex64>::value,
                 "Float or complex type required when specifying an ErrorSpec");
-  std::unique_ptr<Literal> expected_literal =
+  Literal expected_literal =
       LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
-  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, expected_literal,
                                                   arguments, error);
 }
 
@@ -542,13 +537,13 @@
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
     NativeT value, int64 parameter_number, const string& name,
     XlaBuilder* builder, XlaOp* data_handle) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
-  if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralUtil::ConvertF32ToBF16(*literal);
+  Literal literal = LiteralUtil::CreateR0(value);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
   }
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
   return data;
 }
 
@@ -556,13 +551,13 @@
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
     absl::Span<const NativeT> values, int64 parameter_number,
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
-  if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralUtil::ConvertF32ToBF16(*literal);
+  Literal literal = LiteralUtil::CreateR1(values);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
   }
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
   return data;
 }
 
@@ -570,13 +565,13 @@
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
     const Array2D<NativeT>& array_2d, int64 parameter_number,
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
-  if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralUtil::ConvertF32ToBF16(*literal);
+  Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
   }
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
   return data;
 }
 
@@ -584,13 +579,13 @@
 std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
     const Array3D<NativeT>& array_3d, int64 parameter_number,
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
-  if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralUtil::ConvertF32ToBF16(*literal);
+  Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+  if (use_bfloat16_ && literal.shape().element_type() == F32) {
+    literal = LiteralUtil::ConvertF32ToBF16(literal);
   }
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  *data_handle = Parameter(builder, parameter_number, literal.shape(), name);
   return data;
 }
 
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index c898dac..6f2ca84 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -55,16 +55,15 @@
           std::unique_ptr<GlobalData> data,
           client_->Execute(computation, {}, &execution_options));
 
-      std::unique_ptr<Literal> expected_literal =
-          LiteralUtil::CreateR2WithLayout<int32>(
-              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
+      Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+          {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
 
       TF_ASSERT_OK_AND_ASSIGN(
-          auto computed, client_->Transfer(*data, &expected_literal->shape()));
+          auto computed, client_->Transfer(*data, &expected_literal.shape()));
 
       ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
-          expected_literal->shape(), computed->shape()));
-      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+          expected_literal.shape(), computed.shape()));
+      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
     }
   }
 }
@@ -91,19 +90,19 @@
       auto result,
       client_->ExecuteAndTransfer(computation, {}, &execution_options));
   LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
-                                        LiteralSlice(*result, {0}));
+                                        LiteralSlice(result, {0}));
   LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
-                                        LiteralSlice(*result, {1}));
+                                        LiteralSlice(result, {1}));
 
-  EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
-  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+  EXPECT_TRUE(ShapeUtil::IsTuple(result.shape()));
+  EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.shape()));
 
   EXPECT_TRUE(ShapeUtil::Equal(
-      ShapeUtil::GetTupleElementShape(result->shape(), 0),
+      ShapeUtil::GetTupleElementShape(result.shape(), 0),
       ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1})));
   EXPECT_TRUE(ShapeUtil::Equal(
-      ShapeUtil::GetTupleElementShape(result->shape(), 1),
+      ShapeUtil::GetTupleElementShape(result.shape(), 1),
       ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})));
 }
@@ -114,7 +113,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
                           client_->TransferToServer(
-                              *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
+                              LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
 
   XlaBuilder b(TestName() + ".add");
   Add(Parameter(&b, 0, shape, "param_0"),
@@ -140,9 +139,9 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       auto result_literal,
-      client_->Transfer(*results[0], &expected_result->shape()));
+      client_->Transfer(*results[0], &expected_result.shape()));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result_literal));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 03d5696..6ef7ca0 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -42,14 +42,14 @@
                                absl::Span<GlobalData* const> arguments,
                                float expected_result, bool expect_cache_hit) {
     ExecutionProfile execution_profile;
-    std::unique_ptr<Literal> result =
+    Literal result =
         client_
             ->ExecuteAndTransfer(computation, arguments,
                                  /*execution_options=*/&execution_options_,
                                  &execution_profile)
             .ConsumeValueOrDie();
     EXPECT_TRUE(LiteralTestUtil::Near(
-        *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
+        LiteralUtil::CreateR0<float>(expected_result), result, error_spec_));
     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
   }
 
@@ -63,10 +63,9 @@
                            ->Execute(computation, arguments,
                                      &execution_options_, &execution_profile)
                            .ConsumeValueOrDie();
-    std::unique_ptr<Literal> result =
-        client_->Transfer(*data_handle).ConsumeValueOrDie();
+    Literal result = client_->Transfer(*data_handle).ConsumeValueOrDie();
     EXPECT_TRUE(LiteralTestUtil::Near(
-        *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
+        LiteralUtil::CreateR2<float>(expected_result), result, error_spec_));
     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
   }
 
@@ -88,13 +87,13 @@
 XLA_TEST_F(CompilationCacheTest,
            DISABLED_ComputationCalledWithDifferentParameters) {
   std::unique_ptr<GlobalData> data_42 =
-      client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+      client_->TransferToServer(LiteralUtil::CreateR0<float>(42.0f))
           .ConsumeValueOrDie();
   std::unique_ptr<GlobalData> data_123 =
-      client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+      client_->TransferToServer(LiteralUtil::CreateR0<float>(123.0f))
           .ConsumeValueOrDie();
   std::unique_ptr<GlobalData> data_456 =
-      client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+      client_->TransferToServer(LiteralUtil::CreateR0<float>(456.0f))
           .ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
@@ -145,12 +144,12 @@
   auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
       {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
   auto rowmaj_handle =
-      client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+      client_->TransferToServer(rowmaj_array).ConsumeValueOrDie();
 
   auto colmaj_array = LiteralUtil::CreateR2WithLayout(
       {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
   auto colmaj_handle =
-      client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+      client_->TransferToServer(colmaj_array).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 8226b6d..3b0414a 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -69,9 +69,9 @@
     LOG(FATAL) << "invalid client_type value";
   }
 
-  StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
-      Client* client, const XlaOp& operand, XlaBuilder* builder,
-      Layout* output_layout = nullptr) {
+  StatusOr<Literal> ComputeConstantLiteral(Client* client, const XlaOp& operand,
+                                           XlaBuilder* builder,
+                                           Layout* output_layout = nullptr) {
     TF_ASSIGN_OR_RETURN(auto subgraph, builder->BuildConstantSubGraph(operand));
     TF_ASSIGN_OR_RETURN(auto computed,
                         client->ComputeConstant(subgraph, output_layout));
@@ -83,7 +83,7 @@
                                          XlaBuilder* builder) {
     TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand,
                                                              builder, nullptr));
-    return literal->Get<Scalar>({});
+    return literal.Get<Scalar>({});
   }
 
   bool IsConstant(const XlaOp& operand, XlaBuilder* builder) {
@@ -206,9 +206,8 @@
 
     TF_ASSERT_OK_AND_ASSIGN(auto computed,
                             ComputeConstantLiteral(client, computation, &b));
-    std::unique_ptr<Literal> expected_literal =
-        LiteralUtil::CreateR1<int32>({4, 6});
-    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+    Literal expected_literal = LiteralUtil::CreateR1<int32>({4, 6});
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
   }
 }
 
@@ -221,8 +220,8 @@
 
     TF_ASSERT_OK_AND_ASSIGN(auto computed,
                             ComputeConstantLiteral(client, computation, &b));
-    std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
-    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+    Literal expected_literal = LiteralUtil::CreateR0<int32>(5);
+    EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
   }
 }
 
@@ -241,12 +240,11 @@
                                  ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
                              &b, &layout_proto));
 
-      std::unique_ptr<Literal> expected_literal =
-          LiteralUtil::CreateR2WithLayout<int32>(
-              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
+      Literal expected_literal = LiteralUtil::CreateR2WithLayout<int32>(
+          {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
       ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
-          expected_literal->shape(), computed->shape()));
-      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
+          expected_literal.shape(), computed.shape()));
+      EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, computed));
     }
   }
 }
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index be01747..9811a01 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -536,8 +536,8 @@
   auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
   auto x_literal = LiteralUtil::CreateR0<float>(2.f);
   auto y_literal = LiteralUtil::CreateR0<float>(3.f);
-  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
-  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
   auto x = Parameter(&builder, 0, f32_scalar, "x");
@@ -559,12 +559,12 @@
   auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
-  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
-  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
-  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
-  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
   auto y = Parameter(&builder, 1, f32_scalar, "y");
   auto z = Parameter(&builder, 2, f32_scalar, "z");
   auto bcast = Broadcast(y, {5});
@@ -587,12 +587,12 @@
   auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
   auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
   auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
-  auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
-  auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
-  auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
+  auto x_data = client_->TransferToServer(x_literal).ConsumeValueOrDie();
+  auto y_data = client_->TransferToServer(y_literal).ConsumeValueOrDie();
+  auto z_data = client_->TransferToServer(z_literal).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
-  auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
   auto y = Parameter(&builder, 1, f32_scalar, "y");
   auto z = Parameter(&builder, 2, f32_scalar, "y");
   auto y_bcast = Broadcast(y, {1, 5, 7});
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 25d10ab..32cac49 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -359,8 +359,8 @@
 
   ComputeAndCompareTuple(
       &builder,
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
-                               LiteralUtil::CreateR0<float>(25.0f).get()}),
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
+                                        LiteralUtil::CreateR0<float>(25.0f)}),
       {pred_arg.get()}, error_spec_);
 }
 
@@ -375,12 +375,11 @@
   Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
               CreateR1TupleFloorComputation());
 
-  ComputeAndCompareTuple(
-      &builder,
-      *LiteralUtil::MakeTuple(
-          {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
-           LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
-      {pred_arg.get()}, error_spec_);
+  ComputeAndCompareTuple(&builder,
+                         LiteralUtil::MakeTupleFromSlices(
+                             {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
+                              LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
+                         {pred_arg.get()}, error_spec_);
 }
 
 // Test true and false computations that return a tuple of a predicate, a
@@ -415,13 +414,12 @@
   Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
               false_builder_result.ConsumeValueOrDie());
 
-  ComputeAndCompareTuple(
-      &builder,
-      *LiteralUtil::MakeTuple(
-          {LiteralUtil::CreateR0<bool>(true).get(),
-           LiteralUtil::CreateR0<float>(12.2f).get(),
-           LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
-      {pred_arg.get()}, error_spec_);
+  ComputeAndCompareTuple(&builder,
+                         LiteralUtil::MakeTupleFromSlices(
+                             {LiteralUtil::CreateR0<bool>(true),
+                              LiteralUtil::CreateR0<float>(12.2f),
+                              LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
+                         {pred_arg.get()}, error_spec_);
 }
 
 // Test true and false computations that return a nested tuple.
@@ -463,15 +461,13 @@
 
   ComputeAndCompareTuple(
       &builder,
-      *LiteralUtil::MakeTuple(
-          {LiteralUtil::MakeTuple(
-               {LiteralUtil::CreateR0<float>(46.6f).get(),
-                LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
-               .get(),
-           LiteralUtil::MakeTuple(
-               {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
-                LiteralUtil::CreateR0<float>(9.3f).get()})
-               .get()}),
+      LiteralUtil::MakeTupleFromSlices(
+          {LiteralUtil::MakeTupleFromSlices(
+               {LiteralUtil::CreateR0<float>(46.6f),
+                LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
+           LiteralUtil::MakeTupleFromSlices(
+               {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
+                LiteralUtil::CreateR0<float>(9.3f)})}),
       {pred_arg.get()}, error_spec_);
 }
 
@@ -633,8 +629,8 @@
 
     ComputeAndCompareTuple(
         &builder,
-        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
-                                 LiteralUtil::CreateR0<float>(b).get()}),
+        LiteralUtil::MakeTupleFromSlices(
+            {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
         {x_arg.get(), y_arg.get()}, error_spec_);
   };
 
@@ -669,10 +665,10 @@
   {
     // Pred is true case.
     std::vector<Literal> args;
-    args.push_back(std::move(
-        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
-                                 LiteralUtil::CreateR0<int32>(-42).get()})));
-    args.push_back(std::move(*LiteralUtil::CreateR0<bool>(true)));
+    args.push_back(
+        LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+                                          LiteralUtil::CreateR0<int32>(-42)}));
+    args.push_back(LiteralUtil::CreateR0<bool>(true));
     XlaBuilder builder(TestName() + ".main");
     auto p = Parameter(&builder, 0, tuple2, "p0");
     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
@@ -682,10 +678,10 @@
   {
     // Pred is false case.
     std::vector<Literal> args;
-    args.push_back(std::move(
-        *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(123).get(),
-                                 LiteralUtil::CreateR0<int32>(-42).get()})));
-    args.push_back(std::move(*LiteralUtil::CreateR0<bool>(false)));
+    args.push_back(
+        LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<int32>(123),
+                                          LiteralUtil::CreateR0<int32>(-42)}));
+    args.push_back(LiteralUtil::CreateR0<bool>(false));
     XlaBuilder builder(TestName() + ".main");
     auto p = Parameter(&builder, 0, tuple2, "p0");
     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 4937574..72ff1e7 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -110,7 +110,7 @@
 
 TEST_F(ConstantsTest, Empty_3x0x2) {
   XlaBuilder builder(TestName());
-  ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(
                                 Array3D<float>(3, 0, 2)));
 
   ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
@@ -126,7 +126,7 @@
       {{5.f, 6.f},   // y0
        {7.f, 8.f}},  // y1
   });
-  ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+  ConstantLiteral(&builder, LiteralUtil::CreateR3FromArray3D<float>(array3d));
 
   ComputeAndCompareR3<float>(&builder, array3d, {});
 }
@@ -140,12 +140,11 @@
       {5.0f, 4.4f},   // p2
   });
   input_array.FillWithPZ(pz);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4D(input_array);
+  Literal input_literal = LiteralUtil::CreateR4FromArray4D(input_array);
 
   {
     XlaBuilder builder(TestName());
-    ConstantLiteral(&builder, *input_literal);
+    ConstantLiteral(&builder, input_literal);
     ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
   }
 
@@ -159,23 +158,21 @@
 // TODO(b/29263943): Support tuple constants.
 TEST_F(ConstantsTest, DISABLED_TupleConstant) {
   XlaBuilder builder(TestName());
-  ConstantLiteral(&builder,
-                  *LiteralUtil::MakeTuple(
-                      {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
-                       LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+  ConstantLiteral(&builder, LiteralUtil::MakeTupleFromSlices(
+                                {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}),
+                                 LiteralUtil::CreateR1<float>({2.0, 42})}));
 
-  std::unique_ptr<Literal> result =
-      ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
+  Literal result = ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
 
   LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
-                                       LiteralSlice(*result, {0}), error_spec_);
-  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+                                       LiteralSlice(result, {0}), error_spec_);
+  LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(result, {1}),
                                        error_spec_);
 }
 
 TEST_F(ConstantsTest, Token) {
   XlaBuilder builder(TestName());
-  ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+  ConstantLiteral(&builder, LiteralUtil::CreateToken());
   // TODO(b/80000000): tokens cannot be returned from computations.
   Tuple(&builder, {});
   TF_ASSERT_OK(Execute(&builder, {}).status());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 7a203d6..5f063e6 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -210,10 +210,10 @@
       static_cast<int64>(0x8000008000000000LL),
       static_cast<int64>(0x8000010000000000LL),
   };
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<int64>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, F32);
 
@@ -229,10 +229,10 @@
   std::vector<uint32> arg{0,          1,          0x1000,     0x7fffffff,
                           0x80000000, 0x80000001, 0x80000002, 0x80000003,
                           0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, F32);
 
@@ -247,10 +247,10 @@
   XlaBuilder builder(TestName());
   std::vector<float> arg{0.0f,        1.0f,          16777216.0f,
                          16777218.0f, 2147483647.0f, 4294967040.0f};
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, U32);
 
@@ -264,10 +264,10 @@
 XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
   XlaBuilder builder(TestName());
   std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, S64);
 
@@ -281,10 +281,10 @@
 XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
   XlaBuilder builder(TestName());
   std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<int32>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, S64);
 
@@ -318,10 +318,10 @@
                          9223370937343148032.f,
                          -9223371487098961920.f,
                          -9223370937343148032.f};
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
-  auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
+  Literal arg_literal = LiteralUtil::CreateR1<float>({arg});
+  auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param");
   std::unique_ptr<GlobalData> arg_data =
-      client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
+      client_->TransferToServer(arg_literal).ConsumeValueOrDie();
 
   ConvertElementType(arg_param, S64);
 
@@ -456,7 +456,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> dot_lhs_handle,
-      client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
+      client_->TransferToServer(LiteralUtil::CreateR1<half>(input)));
 
   XlaBuilder builder(TestName());
   ConvertElementType(
@@ -476,7 +476,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> dot_lhs_handle,
-      client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
+      client_->TransferToServer(LiteralUtil::CreateR1<float>(input)));
 
   XlaBuilder builder(TestName());
   ConvertElementType(
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 38b6da4..fd98bf2 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,8 +93,7 @@
   auto weight_array = absl::make_unique<Array4D<float>>(4, 3, 1, 1);
   weight_array->FillWithMultiples(0.2);
   auto weight_data =
-      client_
-          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+      client_->TransferToServer(LiteralUtil::CreateR4FromArray4D(*weight_array))
           .ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index d2c6478..070b092 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -123,8 +123,8 @@
     }));
 
     ComputeAndCompare(&builder,
-                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
+                      {LiteralUtil::CreateFromArray(input_data),
+                       LiteralUtil::CreateFromArray(filter_data)},
                       error_spec_);
   }
 };
@@ -157,8 +157,8 @@
         {7.0f, 8.0f},
     }));
     ComputeAndCompare(&builder,
-                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
+                      {LiteralUtil::CreateFromArray(input_data),
+                       LiteralUtil::CreateFromArray(filter_data)},
                       error_spec_);
   }
 };
@@ -192,8 +192,8 @@
     }));
 
     ComputeAndCompare(&builder,
-                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
+                      {LiteralUtil::CreateFromArray(input_data),
+                       LiteralUtil::CreateFromArray(filter_data)},
                       error_spec_);
   }
 };
@@ -224,8 +224,8 @@
         {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
     // clang-format on
     ComputeAndCompare(&builder,
-                      {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                       std::move(*LiteralUtil::CreateFromArray(filter_data))},
+                      {LiteralUtil::CreateFromArray(input_data),
+                       LiteralUtil::CreateFromArray(filter_data)},
                       error_spec_);
   }
 };
@@ -249,10 +249,10 @@
   Array3D<float> expected({{{510, 610, 710, 810}}});
 
   auto input_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
           .ConsumeValueOrDie();
   auto filter_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
           .ConsumeValueOrDie();
 
   ComputeAndCompareR3<float>(&builder, expected,
@@ -284,10 +284,10 @@
     Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
 
     auto input_literal =
-        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
             .ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
             .ConsumeValueOrDie();
 
     ComputeAndCompareR3<T>(&builder, expected,
@@ -319,10 +319,10 @@
   Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
 
   auto input_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
           .ConsumeValueOrDie();
   auto filter_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
           .ConsumeValueOrDie();
 
   ComputeAndCompareR3<float>(&builder, expected,
@@ -350,10 +350,10 @@
   Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
 
   auto input_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
           .ConsumeValueOrDie();
   auto filter_literal =
-      client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
           .ConsumeValueOrDie();
 
   ComputeAndCompareR3<float>(&builder, expected,
@@ -386,10 +386,10 @@
         {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
 
     auto input_literal =
-        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
             .ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+        client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
             .ConsumeValueOrDie();
 
     ComputeAndCompareR3<T>(&builder, expected,
@@ -435,23 +435,23 @@
   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
   iota(input_elems.begin(), input_elems.end(), 1.0f);
   auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
-  auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+  auto input_r5 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 
   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
   iota(filter_elems.begin(), filter_elems.end(), 1.0f);
   auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
-  auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+  auto filter_r5 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 
   auto expected_r1 = LiteralUtil::CreateR1<float>(
       {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
        38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
-  auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
+  auto expected_r5 = expected_r1.Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
 
-  auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+  auto input_literal = client_->TransferToServer(input_r5).ConsumeValueOrDie();
   auto filter_literal =
-      client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+      client_->TransferToServer(filter_r5).ConsumeValueOrDie();
 
-  ComputeAndCompareLiteral(&builder, *expected_r5,
+  ComputeAndCompareLiteral(&builder, expected_r5,
                            {input_literal.get(), filter_literal.get()},
                            error_spec_);
 }
@@ -498,23 +498,23 @@
     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
     iota_int_init_value(input_elems, 1);
     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
-    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 
     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
     iota_int_init_value(filter_elems, 1);
     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
-    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 
     auto expected_r1 = LiteralUtil::CreateR1<T>(
         {static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
-    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
+    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
 
     auto input_literal =
-        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+        client_->TransferToServer(input_r4).ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 
-    ComputeAndCompareLiteral(&builder, *expected_r4,
+    ComputeAndCompareLiteral(&builder, expected_r4,
                              {input_literal.get(), filter_literal.get()},
                              error_spec_);
   }
@@ -558,12 +558,12 @@
     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
     iota_int_init_value(input_elems, 1);
     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
-    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 
     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
     iota_int_init_value(filter_elems, 1);
     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
-    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 
     auto expected_r1 = LiteralUtil::CreateR1<T>(
         {static_cast<T>(16029), static_cast<T>(16218), static_cast<T>(16407),
@@ -571,14 +571,14 @@
          static_cast<T>(18369), static_cast<T>(18576), static_cast<T>(18783),
          static_cast<T>(19620), static_cast<T>(19836), static_cast<T>(20052),
          static_cast<T>(20925), static_cast<T>(21150), static_cast<T>(21375)});
-    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
+    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 15}).ConsumeValueOrDie();
 
     auto input_literal =
-        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+        client_->TransferToServer(input_r4).ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 
-    ComputeAndCompareLiteral(&builder, *expected_r4,
+    ComputeAndCompareLiteral(&builder, expected_r4,
                              {input_literal.get(), filter_literal.get()},
                              error_spec_);
   }
@@ -624,26 +624,26 @@
     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
     iota_int_init_value(input_elems, 1);
     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
-    auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+    auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 
     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
     iota_int_init_value(filter_elems, 1);
     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
-    auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+    auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 
     auto expected_r1 = LiteralUtil::CreateR1<T>(
         {static_cast<T>(5076), static_cast<T>(5160), static_cast<T>(5244),
          static_cast<T>(5328), static_cast<T>(6164), static_cast<T>(6264),
          static_cast<T>(6364), static_cast<T>(6464), static_cast<T>(7380),
          static_cast<T>(7496), static_cast<T>(7612), static_cast<T>(7728)});
-    auto expected_r4 = expected_r1->Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
+    auto expected_r4 = expected_r1.Reshape({1, 1, 1, 12}).ConsumeValueOrDie();
 
     auto input_literal =
-        client_->TransferToServer(*input_r4).ConsumeValueOrDie();
+        client_->TransferToServer(input_r4).ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
+        client_->TransferToServer(filter_r4).ConsumeValueOrDie();
 
-    ComputeAndCompareLiteral(&builder, *expected_r4,
+    ComputeAndCompareLiteral(&builder, expected_r4,
                              {input_literal.get(), filter_literal.get()},
                              error_spec_);
   }
@@ -692,8 +692,8 @@
   expected_result.Fill(0);
 
   ComputeAndCompare(&builder,
-                    {std::move(*LiteralUtil::CreateFromArray(param0)),
-                     std::move(*LiteralUtil::CreateFromArray(param1))},
+                    {LiteralUtil::CreateFromArray(param0),
+                     LiteralUtil::CreateFromArray(param1)},
                     error_spec_);
 }
 
@@ -749,26 +749,25 @@
     std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
                                static_cast<T>(1.0f));
     auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
-    auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
+    auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
 
     std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
                                 static_cast<T>(1.0f));
 
     auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
-    auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
+    auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
 
     std::vector<T> expect_elems(batch * output_feature * num_windows,
                                 static_cast<T>(window_size * input_feature));
     auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
-    auto expected_r3 =
-        expected_r1->Reshape({batch, num_windows, output_feature})
-            .ConsumeValueOrDie();
+    auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
+                           .ConsumeValueOrDie();
 
     auto input_literal =
-        client_->TransferToServer(*input_r3).ConsumeValueOrDie();
+        client_->TransferToServer(input_r3).ConsumeValueOrDie();
     auto filter_literal =
-        client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
-    ComputeAndCompareLiteral(&builder, *expected_r3,
+        client_->TransferToServer(filter_r3).ConsumeValueOrDie();
+    ComputeAndCompareLiteral(&builder, expected_r3,
                              {input_literal.get(), filter_literal.get()},
                              error_spec_);
   }
@@ -868,8 +867,8 @@
   }));
 
   ComputeAndCompare(&builder,
-                    {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                     std::move(*LiteralUtil::CreateFromArray(filter_data))},
+                    {LiteralUtil::CreateFromArray(input_data),
+                     LiteralUtil::CreateFromArray(filter_data)},
                     error_spec_);
 }
 
@@ -891,9 +890,44 @@
   Array4D<float> filter_data(1, 1, 1, 2);
   filter_data.FillIota(10);
 
-  ComputeAndCompare(&builder,
-                    {std::move(*LiteralUtil::CreateFromArray(input_data)),
-                     std::move(*LiteralUtil::CreateFromArray(filter_data))});
+  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data),
+                               LiteralUtil::CreateFromArray(filter_data)});
+}
+
+XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) {
+  XlaBuilder builder(TestName());
+  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 64, 100, 100});
+  Array4D<float> input_data(1, 64, 100, 100);
+  input_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45321);
+  Shape filter_shape = ShapeUtil::MakeShape(F32, {7, 7, 1, 64});
+  Array4D<float> filter_data(7, 7, 1, 64);
+  filter_data.FillRandom(/*value=*/0.023, 0.001, /*seed=*/45320);
+  auto input = Parameter(&builder, 0, input_shape, "input");
+  auto filter = ConstantR4FromArray4D(&builder, filter_data);
+
+  // Specify bf01_01io->bf01 as dimension numbers.
+  ConvolutionDimensionNumbers dnums;
+  // Input
+  dnums.set_input_feature_dimension(1);
+  dnums.set_input_batch_dimension(0);
+  dnums.add_input_spatial_dimensions(2);
+  dnums.add_input_spatial_dimensions(3);
+  // Kernel
+  dnums.set_kernel_input_feature_dimension(2);
+  dnums.set_kernel_output_feature_dimension(3);
+  dnums.add_kernel_spatial_dimensions(0);
+  dnums.add_kernel_spatial_dimensions(1);
+  // Output
+  dnums.set_output_batch_dimension(0);
+  dnums.set_output_feature_dimension(1);
+  dnums.add_output_spatial_dimensions(2);
+  dnums.add_output_spatial_dimensions(3);
+  ConvGeneral(input, filter, /*window_strides=*/{1, 1},
+              /*padding=*/{{3, 3}, {3, 3}}, /*dimension_numbers=*/dnums,
+              /*feature_group_count=*/64);
+
+  ComputeAndCompare(&builder, {LiteralUtil::CreateFromArray(input_data)},
+                    error_spec_);
 }
 
 class ConvolutionHloTest : public HloTestBase {};
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 6784c16..ba3e9c4 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -1335,23 +1335,23 @@
 
   auto gradients_flat = LiteralUtil::CreateR1<float>({1});
   auto gradients_literal =
-      gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
-  auto gradients = ConstantLiteral(&builder, *gradients_literal);
+      gradients_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+  auto gradients = ConstantLiteral(&builder, gradients_literal);
 
   auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
   auto weights_literal =
-      weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
-  auto weights = ConstantLiteral(&builder, *weights_literal);
+      weights_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+  auto weights = ConstantLiteral(&builder, weights_literal);
 
   auto expected_flat = LiteralUtil::CreateR1<float>({10});
   auto expected_literal =
-      expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
+      expected_flat.Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
 
   auto mirrored_weights = Rev(weights, {2, 3, 4});
   ConvWithGeneralPadding(gradients, mirrored_weights,
                          /*window_strides=*/{1, 1, 1},
                          /*padding=*/{{0, 0}, {0, 0}, {1, 1}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+  ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
 }
 
 XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
@@ -1359,17 +1359,17 @@
 
   auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
   auto activations_literal =
-      activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
-  auto activations = ConstantLiteral(&builder, *activations_literal);
+      activations_flat.Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
+  auto activations = ConstantLiteral(&builder, activations_literal);
 
   auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
   auto gradients_literal =
-      gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
-  auto gradients = ConstantLiteral(&builder, *gradients_literal);
+      gradients_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+  auto gradients = ConstantLiteral(&builder, gradients_literal);
 
   auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
   auto expected_literal =
-      expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
+      expected_flat.Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
 
   auto forward_conv =
       ConvGeneralDilated(activations, gradients,
@@ -1379,7 +1379,7 @@
                          XlaBuilder::CreateDefaultConvDimensionNumbers(
                              /*num_spatial_dims=*/3));
   Transpose(forward_conv, {0, 1, 2, 3, 4});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
+  ComputeAndCompareLiteral(&builder, expected_literal, {}, error_spec_);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 526626c..1407e68 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -40,16 +40,16 @@
  protected:
   void TestCopyOp(const Literal& literal) {
     auto builder = HloComputation::Builder(TestName());
-    auto constant = builder.AddInstruction(
-        HloInstruction::CreateConstant(literal.CloneToUnique()));
+    auto constant =
+        builder.AddInstruction(HloInstruction::CreateConstant(literal.Clone()));
     builder.AddInstruction(HloInstruction::CreateUnary(
         constant->shape(), HloOpcode::kCopy, constant));
     auto computation = builder.Build();
     auto module = CreateNewModule();
     module->AddEntryComputation(std::move(computation));
 
-    std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
-    EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
+    Literal result = ExecuteAndTransfer(std::move(module), {});
+    EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
   }
 
   void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -58,31 +58,30 @@
 };
 
 XLA_TEST_F(CopyOpTest, CopyR0Bool) {
-  TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+  TestCopyOp(LiteralUtil::CreateR0<bool>(true));
 }
 
 XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
-  TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+  TestCopyOp(LiteralUtil::CreateR1<uint32>({}));
 }
 
 XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
-  TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+  TestCopyOp(LiteralUtil::CreateR1<uint32>({1, 2, 3}));
 }
 
 XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
-  TestCopyOp(
-      *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
-                              {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+  TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+                                    {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
 }
 
 XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
-  TestCopyOp(*LiteralUtil::CreateR4(
+  TestCopyOp(LiteralUtil::CreateR4(
       {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
        {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
 }
 
 XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
-  TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+  TestCopyOp(LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
 }
 
 XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
@@ -90,7 +89,7 @@
 
   // Copy literal to device to use as parameter.
   auto literal = LiteralUtil::CreateR0<float>(42.0);
-  Shape shape = literal->shape();
+  Shape shape = literal.shape();
 
   auto param0 = builder.AddInstruction(
       HloInstruction::CreateParameter(0, shape, "param0"));
@@ -102,9 +101,8 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(std::move(computation));
 
-  std::unique_ptr<Literal> result =
-      ExecuteAndTransfer(std::move(module), {literal.get()});
-  LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+  Literal result = ExecuteAndTransfer(std::move(module), {&literal});
+  LiteralTestUtil::ExpectR0Near<float>(42.0f, result, error_spec_);
 }
 
 XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
@@ -123,19 +121,17 @@
 
   auto module = CreateNewModule();
   module->AddEntryComputation(std::move(computation));
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
-  LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+  Literal result = ExecuteAndTransfer(std::move(module), {});
+  LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, result,
                                        error_spec_);
 }
 
 XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
   HloComputation::Builder builder(TestName());
 
-  std::unique_ptr<Literal> literal =
-      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+  Literal literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
   // Reverse the minor-to-major order of the literal.
-  Layout* literal_layout =
-      literal->mutable_shape_do_not_use()->mutable_layout();
+  Layout* literal_layout = literal.mutable_shape_do_not_use()->mutable_layout();
   ASSERT_EQ(2, literal_layout->minor_to_major_size());
   literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
 
@@ -149,11 +145,11 @@
 
   auto module = CreateNewModule();
   module->AddEntryComputation(std::move(computation));
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+  Literal result = ExecuteAndTransfer(std::move(module), {});
 
   // The result of the computation has the default layout, which is the inverse
   // of the layout of the source literal.
-  LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+  LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, result,
                                        error_spec_);
 }
 
@@ -169,7 +165,7 @@
 
   HloComputation::Builder builder(TestName());
 
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+  Literal literal = LiteralUtil::CreateR3FromArray3D(a);
 
   HloInstruction* constant = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
@@ -182,9 +178,9 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(std::move(computation));
   ForceResultLayout(module.get(), LayoutUtil::MakeLayout({1, 2, 0}));
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+  Literal result = ExecuteAndTransfer(std::move(module), {});
 
-  LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+  LiteralTestUtil::ExpectR3EqualArray3D(a, result);
 }
 
 void CopyOpTest::TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3,
@@ -203,7 +199,7 @@
 
   HloComputation::Builder builder(TestName());
 
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+  Literal literal = LiteralUtil::CreateR4FromArray4D(a);
 
   HloInstruction* constant = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
@@ -216,9 +212,9 @@
   auto module = CreateNewModule();
   module->AddEntryComputation(std::move(computation));
   ForceResultLayout(module.get(), LayoutUtil::MakeLayout(permutation));
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+  Literal result = ExecuteAndTransfer(std::move(module), {});
 
-  LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+  LiteralTestUtil::ExpectR4EqualArray4D(a, result);
 }
 
 XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
@@ -250,11 +246,11 @@
 
   XlaBuilder builder(TestName());
   Parameter(&builder, 0, in_shape, "input");
-  auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
+  auto input_data = client_->TransferToServer(empty).ConsumeValueOrDie();
 
   auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
                     .ConsumeValueOrDie();
-  EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
+  EXPECT_TRUE(LiteralTestUtil::Equal(empty, actual));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index d12a4e7..410732c 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -46,7 +46,7 @@
   auto module =
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
   auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
-  EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
+  EXPECT_EQ(literal, ExecuteAndTransfer(std::move(module), {&literal}));
 }
 
 XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
@@ -68,9 +68,8 @@
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
   auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
   auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
-  EXPECT_EQ(
-      *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
-      *ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
+  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+            ExecuteAndTransfer(std::move(module), {&literal0, &literal1}));
 }
 
 // On the GPU backend, constants get special handling.  Someone might pass a
@@ -95,8 +94,8 @@
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
   auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
   auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
-  EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
-            *ExecuteAndTransfer(std::move(module), {literal0.get()}));
+  EXPECT_EQ(LiteralUtil::MakeTuple({&literal0, &literal1}),
+            ExecuteAndTransfer(std::move(module), {&literal0}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index 6f7fc0e..a693fa3 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -80,8 +80,8 @@
 
   module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
-  LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+  Literal result = ExecuteAndTransfer(std::move(module), {});
+  LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
 }
 
 XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
@@ -101,8 +101,8 @@
 
   module->AddEntryComputation(builder.Build());
 
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
-  LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+  Literal result = ExecuteAndTransfer(std::move(module), {});
+  LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
 }
 
 XLA_TEST_F(CustomCallTest,
@@ -125,9 +125,9 @@
 
   module->AddEntryComputation(b.Build());
 
-  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
+  Literal result = ExecuteAndTransfer(std::move(module), {});
   LiteralTestUtil::ExpectR3EqualArray3D<float>(
-      Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+      Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
 }
 
 class CustomCallClientAPITest : public ClientLibraryTestBase {};
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index eb15fc0..e0f23b0 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -64,11 +64,11 @@
 
   // Try copying the elements back and comparing it
   auto handles = result_status.ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal;
+  Literal literal;
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 }
 
 TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
@@ -86,19 +86,19 @@
   auto handles1 = result_status1.ConsumeValueOrDie();
   auto handles2 = result_status2.ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal;
+  Literal literal;
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[0]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles1[1]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 
   handles1[0].reset();
   handles1[1].reset();
 
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[0]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles2[1]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
 }
 
 XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
@@ -116,15 +116,15 @@
   // the same as handle[3] and handle[1] should be the same as handle[2].
   auto handles = result_status.ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal;
+  Literal literal;
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[3]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 }
 
 TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
@@ -142,19 +142,19 @@
   // should not have been deallocated because of reference counting.
   global_data.reset();
 
-  std::unique_ptr<Literal> literal;
+  Literal literal;
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[0]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[1]));
-  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({2.0, 4.0, 6.0, 8.0}, literal);
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 
   /// Try deallocating one of the repeated elements, then copy
   handles[0].reset();
 
   TF_ASSERT_OK_AND_ASSIGN(literal, client_->Transfer(*handles[2]));
-  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, *literal);
+  LiteralTestUtil::ExpectR1Equal<float>({1.0, 2.0, 3.0, 4.0}, literal);
 }
 
 TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
@@ -170,10 +170,9 @@
 
 XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
-      LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+  Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
   Tuple(&builder, {p});
   auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 5873516..0171f51 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -68,16 +68,16 @@
   XlaOp param;
   auto param_data = CreateParameterAndTransferLiteral(
       0,
-      *LiteralUtil::MakeTuple(
-          {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
-           LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+      LiteralUtil::MakeTupleFromSlices(
+          {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}),
+           LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}})}),
       "arg0", &builder, &param);
   auto lhs = GetTupleElement(param, 0);
   auto rhs = GetTupleElement(param, 1);
   Dot(lhs, rhs);
 
   ComputeAndCompareLiteral(&builder,
-                           *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
+                           LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
                            {param_data.get()});
 }
 
@@ -196,11 +196,11 @@
 
   auto lhs_handle =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+          ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
               {{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
           .ConsumeValueOrDie();
   auto rhs_handle = this->client_
-                        ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
+                        ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
                             {{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
                         .ConsumeValueOrDie();
 
@@ -219,14 +219,14 @@
   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
     auto lhs_handle =
         client_
-            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                 {{1.0f, 2.0f}, {3.0f, -4.0f}},
                 LayoutUtil::MakeLayout(
                     MinorToMajorForIsRowMajor(lhs_row_major))))
             .ConsumeValueOrDie();
     auto rhs_handle =
         client_
-            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                 {{1.0f, 6.0f}, {7.0f, -4.0f}},
                 LayoutUtil::MakeLayout(
                     MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,24 +286,23 @@
 
   std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
-  std::unique_ptr<Literal> dot_lhs_lit =
-      LiteralUtil::CreateR2FromArray2DWithLayout(
-          *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
-                             param.dot_lhs_row_major)));
+  Literal dot_lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
+      *dot_lhs_data, LayoutUtil::MakeLayout(
+                         MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
   std::unique_ptr<GlobalData> dot_lhs_handle =
-      client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
+      client_->TransferToServer(dot_lhs_lit).ConsumeValueOrDie();
 
   std::unique_ptr<Array2D<NativeT>> dot_rhs_data =
       MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.k, param.n);
   Layout rhs_layout = LayoutUtil::MakeLayout(
       MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
-  std::unique_ptr<Literal> dot_rhs_lit =
+  Literal dot_rhs_lit =
       LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
   std::unique_ptr<GlobalData> dot_rhs_handle =
-      client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
+      client_->TransferToServer(dot_rhs_lit).ConsumeValueOrDie();
 
   std::unique_ptr<Array2D<NativeT>> addend_data;
-  std::unique_ptr<Literal> addend_lit;
+  Literal addend_lit;
   std::unique_ptr<GlobalData> addend_handle;
 
   if (param.has_addend) {
@@ -311,7 +310,7 @@
     addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
         *addend_data, LayoutUtil::MakeLayout(
                           MinorToMajorForIsRowMajor(param.addend_row_major)));
-    addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
+    addend_handle = client_->TransferToServer(addend_lit).ConsumeValueOrDie();
   }
 
   XlaBuilder builder(TestName());
@@ -477,14 +476,14 @@
   void TestImpl(bool lhs_row_major, bool rhs_row_major) {
     auto lhs_handle =
         client_
-            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                 {{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
                 LayoutUtil::MakeLayout(
                     MinorToMajorForIsRowMajor(lhs_row_major))))
             .ConsumeValueOrDie();
     auto rhs_handle =
         client_
-            ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
+            ->TransferToServer(LiteralUtil::CreateFromArrayWithLayout<T>(
                 {{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
                 LayoutUtil::MakeLayout(
                     MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -511,12 +510,12 @@
 XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
               {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
+          ->TransferToServer(LiteralUtil::CreateR2WithLayout<complex64>(
               {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
               LayoutUtil::MakeLayout({1, 0})))
           .ConsumeValueOrDie();
@@ -584,7 +583,7 @@
   Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
 
   auto x_data = this->client_
-                    ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+                    ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
                         {{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
                           {{2000.0f, 200.0f}, {20.0f, 2.0f}}},
                          {{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +591,7 @@
                     .ConsumeValueOrDie();
   auto y_data =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
               {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
                {{{11.0f, 22.0f}, {33.0f, 44.0f}},
                 {{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -630,13 +629,13 @@
 
   auto x_data =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
               {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
           .ConsumeValueOrDie();
 
   auto y_data =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
+          ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
               {{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
           .ConsumeValueOrDie();
 
@@ -668,7 +667,7 @@
 
   auto x_data =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
               {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
                {{{9.0f, 10.0f}, {11.0f, 12.0f}},
                 {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
@@ -676,7 +675,7 @@
 
   auto y_data =
       this->client_
-          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+          ->TransferToServer(LiteralUtil::CreateR4FromArray4D<T>(
               {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
                {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
           .ConsumeValueOrDie();
@@ -708,14 +707,14 @@
         auto lhs_handle =
             this->client_
                 ->TransferToServer(
-                    *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                         *lhs, LayoutUtil::MakeLayout(
                                   MinorToMajorForIsRowMajor(row_major))))
                 .ConsumeValueOrDie();
         auto rhs_handle =
             this->client_
                 ->TransferToServer(
-                    *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+                    LiteralUtil::CreateR2FromArray2DWithLayout<T>(
                         *rhs, LayoutUtil::MakeLayout(
                                   MinorToMajorForIsRowMajor(row_major))))
                 .ConsumeValueOrDie();
@@ -778,15 +777,15 @@
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_0_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_1_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_2_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 
   Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
   this->template ComputeAndCompareR2<T>(
@@ -827,15 +826,15 @@
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_0_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_1_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
   TF_ASSERT_OK_AND_ASSIGN(
       auto arg_2_value,
       this->client_->TransferToServer(
-          *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
+          LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
 
   Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
   this->template ComputeAndCompareR2<T>(
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 9bf3767..7501c6d 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,13 +124,13 @@
     // vector<bool> is special so that it cannot be a Span<bool>, which
     // is what the code below wants. So instead we do this.
     Literal input_values =
-        std::move(*LiteralUtil::CreateR1(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        LiteralUtil::CreateR1(input_values_int)
+            .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+            .ValueOrDie();
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR1(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR1(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -150,13 +150,13 @@
              const std::vector<int64>& slice_sizes,
              const Array2D<int>& expected_values_int) {
     Literal input_values =
-        std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -176,13 +176,13 @@
              const std::vector<int64>& slice_sizes,
              const Array3D<int>& expected_values_int) {
     Literal input_values =
-        std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -359,17 +359,17 @@
   void RunR0(int input_value_int, int update_value_int,
              const std::vector<IndexT> slice_starts, int expected_value_int) {
     Literal input_value =
-        std::move(*LiteralUtil::CreateR0(input_value_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR0(input_value_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal update_value =
-        std::move(*LiteralUtil::CreateR0(update_value_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR0(update_value_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_value =
-        std::move(*LiteralUtil::CreateR0(expected_value_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR0(expected_value_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -390,17 +390,17 @@
              const std::vector<IndexT> slice_starts,
              absl::Span<const int> expected_values_int) {
     Literal input_values =
-        std::move(*LiteralUtil::CreateR1(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR1(input_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal update_values =
-        std::move(*LiteralUtil::CreateR1(update_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR1(update_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR1(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR1(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -421,17 +421,17 @@
              const std::vector<IndexT> slice_starts,
              const Array2D<int>& expected_values_int) {
     Literal input_values =
-        std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR2FromArray2D(input_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal update_values =
-        std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR2FromArray2D(update_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR2FromArray2D(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -452,17 +452,17 @@
              const std::vector<IndexT> slice_starts,
              const Array3D<int>& expected_values_int) {
     Literal input_values =
-        std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR3FromArray3D(input_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal update_values =
-        std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR3FromArray3D(update_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
     Literal expected_values =
-        std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
-                       ->Convert(primitive_util::NativeToPrimitiveType<DataT>())
-                       .ValueOrDie());
+        std::move(LiteralUtil::CreateR3FromArray3D(expected_values_int)
+                      .Convert(primitive_util::NativeToPrimitiveType<DataT>())
+                      .ValueOrDie());
 
     XlaBuilder builder(TestName());
     // Initialize and transfer dynamic slice start indices parameter.
@@ -529,9 +529,8 @@
 
   template <typename NativeT>
   void DumpArray(const string& name, const Array3D<NativeT> values) {
-    std::unique_ptr<Literal> literal =
-        LiteralUtil::CreateR3FromArray3D<NativeT>(values);
-    LOG(INFO) << name << ":" << literal->ToString();
+    Literal literal = LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+    LOG(INFO) << name << ":" << literal.ToString();
   }
 };
 
@@ -719,7 +718,7 @@
   auto input_literal = LiteralUtil::CreateR4(
       {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
         {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
-  auto input = ConstantLiteral(&builder, *input_literal);
+  auto input = ConstantLiteral(&builder, input_literal);
 
   // Create dynamic slice start indices as a parameter: shape [4]
   auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
@@ -740,7 +739,7 @@
   auto stream =
       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
   ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
-      stream.get(), *start_indices_literal, buffer));
+      stream.get(), start_indices_literal, buffer));
 
   std::unique_ptr<LocalExecutable> executable =
       client
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 5116e60..b08ece0 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,7 +31,7 @@
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> input,
       client_->TransferToServer(
-          *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+          LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
 
   XlaBuilder b(TestName() + ".add");
   Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index bf1de02..51b50d4 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -38,29 +38,29 @@
 
     XlaBuilder builder(TestName());
 
-    std::unique_ptr<Literal> input_literal =
+    Literal input_literal =
         LiteralUtil::CreateFromDimensions(F32, {input_size});
     for (int64 i = begin; i < end; i++) {
       if (i >= known_incorrect_range.first &&
           i < known_incorrect_range.second) {
         // If the operation is known to be buggy on a specific input clamp that
         // input to 0 under the assumption that the op is at least correct on 0.
-        input_literal->Set({i - begin}, 0.0f);
+        input_literal.Set({i - begin}, 0.0f);
       } else {
-        input_literal->Set({i - begin}, tensorflow::bit_cast<float, int>(i));
+        input_literal.Set({i - begin}, tensorflow::bit_cast<float, int>(i));
       }
     }
 
     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
-                            client_->TransferToServer(*input_literal));
+                            client_->TransferToServer(input_literal));
 
-    auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+    auto input = Parameter(&builder, 0, input_literal.shape(), "input");
     enqueue_op(&builder, input);
 
     std::vector<float> expected_result;
     expected_result.reserve(input_size);
     for (int64 i = 0; i < input_size; i++) {
-      expected_result.push_back(evaluate_op(input_literal->Get<float>({i})));
+      expected_result.push_back(evaluate_op(input_literal.Get<float>({i})));
     }
 
     ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7cb2f0c..9c94acb 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -117,9 +117,9 @@
     auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
     if (primitive_util::IsFloatingPointType(prim_type)) {
-      EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
+      EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec(1e-4)));
     } else {
-      EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+      EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
     }
   }
 
@@ -222,8 +222,8 @@
           HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+      LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 // Test whether we emit appropriate code for parameters of fusion instructions.
@@ -248,8 +248,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+      LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -283,7 +283,7 @@
   // Every element of result should be y = x^2 = 4.0.
   for (int i = 0; i < rand_dim0_size; ++i) {
     for (int j = 0; j < dim1_size; ++j) {
-      EXPECT_EQ(4.0, result->Get<float>({i, j}));
+      EXPECT_EQ(4.0, result.Get<float>({i, j}));
     }
   }
 }
@@ -308,8 +308,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(LiteralTestUtil::Near(
-      *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
+      LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+      ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -323,8 +323,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(5),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -338,8 +338,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -353,8 +353,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -368,8 +368,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -383,8 +383,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR3<int32>({{{7}}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape__) {
@@ -398,8 +398,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(7),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -413,8 +413,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -428,8 +428,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -443,8 +443,8 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reverse) {
@@ -459,8 +459,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({3, 2, 1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -477,8 +477,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-3, -2, -1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -495,8 +495,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -1}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, SliceNegate) {
@@ -513,8 +513,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-1, -3}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -535,8 +535,8 @@
           HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({-2, -3}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -552,9 +552,9 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
                                 HloInstruction::FusionKind::kLoop);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, TransposeNegate) {
@@ -570,9 +570,9 @@
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
                                 HloInstruction::FusionKind::kLoop);
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -602,8 +602,8 @@
                                 HloInstruction::FusionKind::kInput);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(15),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -624,8 +624,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR0<int32>(-15),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -674,8 +674,8 @@
                                 HloInstruction::FusionKind::kLoop);
 
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+      ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 // When a constant (or other op) which has multiple users is imported
@@ -710,8 +710,8 @@
   EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
 
   EXPECT_TRUE(
-      LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
-                             *ExecuteAndTransfer(std::move(hlo_module), {})));
+      LiteralTestUtil::Equal(LiteralUtil::CreateR1<int32>({8}),
+                             ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
@@ -782,19 +782,17 @@
 }
 )";
 
-  std::unique_ptr<Literal> operand =
-      LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+  Literal operand = LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
   HloModuleConfig config;
   config.set_debug_options(GetDebugOptionsForTest());
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(hlo_text, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
-      test_runner_.Execute(std::move(module), {operand.get()},
-                           /*run_hlo_passes=*/false));
+  TF_ASSERT_OK_AND_ASSIGN(Literal result,
+                          test_runner_.Execute(std::move(module), {&operand},
+                                               /*run_hlo_passes=*/false));
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
-      *result));
+      LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+      result));
 }
 
 class FusionClientLibraryTest : public ClientLibraryTestBase {};
@@ -821,16 +819,16 @@
   // where overflow is OK.
   Array2D<uint32> arr(32, 32);
   arr.FillUnique();
-  std::unique_ptr<Literal> l1 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+  Literal l1 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
       LayoutUtil::MakeLayout({0, 1}));
 
-  std::unique_ptr<Literal> l2 = LiteralUtil::CreateR2FromArray2D(arr)->Relayout(
+  Literal l2 = LiteralUtil::CreateR2FromArray2D(arr).Relayout(
       LayoutUtil::MakeLayout({1, 0}));
 
-  XlaOp p0 = AddParam(*l1, &b);
+  XlaOp p0 = AddParam(l1, &b);
   XlaOp sum = p0;
   for (int i = 1; i < kNumParams; ++i) {
-    auto pN = AddParam((i % 2 == 0 ? *l1 : *l2), &b);
+    auto pN = AddParam((i % 2 == 0 ? l1 : l2), &b);
     sum = sum + p0 * pN * pN;
   }
 
@@ -879,19 +877,19 @@
   auto param0_literal =
       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
   ScopedShapedBuffer buffer0 =
-      client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
+      client->LiteralToShapedBuffer(param0_literal, device_ordinal)
           .ConsumeValueOrDie();
 
   auto param1_literal =
       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
   ScopedShapedBuffer buffer1 =
-      client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
+      client->LiteralToShapedBuffer(param1_literal, device_ordinal)
           .ConsumeValueOrDie();
 
   auto param2_literal =
       LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
   ScopedShapedBuffer buffer2 =
-      client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
+      client->LiteralToShapedBuffer(param2_literal, device_ordinal)
           .ConsumeValueOrDie();
 
   // Build executable.
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 6d63498..daa8939 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -58,10 +58,10 @@
       slice_sizes={1, 3}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherV2) {
@@ -79,10 +79,10 @@
       slice_sizes={3, 1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherMultipleBatchDims) {
@@ -100,11 +100,10 @@
       slice_sizes={3, 1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_0) {
@@ -122,11 +121,11 @@
       slice_sizes={1, 1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
+  Literal start_indices =
       LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdMultipleBatchDims_1) {
@@ -144,11 +143,11 @@
       slice_sizes={1, 1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
+  Literal start_indices =
       LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherNd) {
@@ -166,13 +165,12 @@
       slice_sizes={1,1,2}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, TensorFlowGatherNdNonDefaultIndexVectorDim) {
@@ -190,13 +188,12 @@
       slice_sizes={1,1,2}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, DynamicSlice) {
@@ -214,10 +211,10 @@
       slice_sizes={1,1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, BatchDynamicSlice) {
@@ -235,11 +232,10 @@
       slice_sizes={1,1}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, ZeroDimBounds) {
@@ -257,9 +253,9 @@
       slice_sizes={1, 0}
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
@@ -281,11 +277,11 @@
   ROOT result = s32[6]{0} reshape(gather)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+  Literal start_indices = LiteralUtil::CreateR2<int32>(
       {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
@@ -307,11 +303,11 @@
   ROOT result = s32[6]{0} reshape(gather)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<uint32>(
+  Literal start_indices = LiteralUtil::CreateR2<uint32>(
       {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, NegativeIndex) {
@@ -333,11 +329,11 @@
   ROOT result = s32[6]{0} reshape(gather)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+  Literal start_indices = LiteralUtil::CreateR2<int32>(
       {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
@@ -359,11 +355,11 @@
   ROOT result = u32[6]{0} reshape(gather)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR2<int32>(
+  Literal start_indices = LiteralUtil::CreateR2<int32>(
       {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -381,10 +377,10 @@
       slice_sizes={1,3,2}
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+  Literal operand = LiteralUtil::CreateR3<int32>(
       {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, ScalarResult) {
@@ -402,9 +398,9 @@
       slice_sizes={1}
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR0<int32>(1);
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+  Literal start_indices = LiteralUtil::CreateR0<int32>(1);
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, ZeroSizedResult) {
@@ -422,10 +418,10 @@
       slice_sizes={1, 3}
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) {
@@ -446,10 +442,10 @@
   ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({0, 2});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) {
@@ -470,11 +466,10 @@
   ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) {
@@ -495,11 +490,11 @@
   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
+  Literal start_indices =
       LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) {
@@ -520,13 +515,12 @@
   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest,
@@ -548,13 +542,12 @@
   ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) {
@@ -575,10 +568,10 @@
   ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices = LiteralUtil::CreateR1<int32>({1, 1});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) {
@@ -599,11 +592,10 @@
   ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted)
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> start_indices =
-      LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
-  RunTest(hlo_text, operand.get(), start_indices.get());
+  Literal start_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+  RunTest(hlo_text, &operand, &start_indices);
 }
 
 class GatherClientLibraryTest : public ClientLibraryTestBase {};
@@ -640,10 +632,10 @@
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> operand_arg,
       client_->TransferToServer(
-          *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+          LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> indices_arg,
-      client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
+      client_->TransferToServer(LiteralUtil::CreateR1<int32>({0, 2})));
   TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
                           client_->GetDeviceHandles(1));
   xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -657,10 +649,9 @@
   TF_ASSERT_OK_AND_ASSIGN(
       std::vector<std::unique_ptr<xla::GlobalData>> result_data,
       client_->ExecuteParallel(computation_instances));
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
+  TF_ASSERT_OK_AND_ASSIGN(Literal result_literal,
                           client_->Transfer(*(result_data[0])));
-  LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
-                                        *result_literal);
+  LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}}, result_literal);
 }
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index fc4c682..bdd4fd7 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -120,6 +120,14 @@
   return status_or;
 }
 
+/* static */
+PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) {
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      operands, PrecisionConfig::DEFAULT);
+  return precision_config;
+}
+
 DebugOptions HloTestBase::GetDebugOptionsForTest() {
   auto debug_options = legacy_flags::GetDebugOptionsFromFlags();
   // TODO(b/38354253): Change tests to use Parameters instead of Constants.
@@ -128,21 +136,21 @@
   return debug_options;
 }
 
-StatusOr<std::unique_ptr<Literal>> HloTestBase::Execute(
-    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+StatusOr<Literal> HloTestBase::Execute(std::unique_ptr<HloModule> module,
+                                       absl::Span<Literal* const> arguments) {
   return test_runner_.Execute(std::move(module), arguments);
 }
 
-std::unique_ptr<Literal> HloTestBase::ExecuteNoHloPasses(
-    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+                                        absl::Span<Literal* const> arguments) {
   return test_runner_
       .Execute(std::move(module), arguments,
                /*run_hlo_passes=*/false)
       .ValueOrDie();
 }
 
-std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
-    std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments) {
+Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+                                        absl::Span<Literal* const> arguments) {
   return test_runner_.Execute(std::move(module), arguments).ValueOrDie();
 }
 
@@ -180,7 +188,7 @@
   TF_ASSIGN_OR_RETURN(auto reference,
                       reference_runner_.Execute(std::move(reference_module),
                                                 arguments, run_hlo_passes));
-  return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test,
+  return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test,
                                       error);
 }
 
@@ -215,13 +223,12 @@
 ::testing::AssertionResult HloTestBase::RunAndCompare(
     std::unique_ptr<HloModule> module, const optional<ErrorSpec>& error,
     const std::function<void(HloModule*)>& reference_preprocessor) {
-  const auto& fake_arguments =
-      MakeFakeArguments(module.get()).ConsumeValueOrDie();
+  auto fake_arguments = MakeFakeArguments(module.get()).ConsumeValueOrDie();
 
   std::vector<Literal*> fake_argument_ptrs;
   absl::c_transform(
       fake_arguments, std::back_inserter(fake_argument_ptrs),
-      [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+      [](const Literal& literal) { return const_cast<Literal*>(&literal); });
 
   return RunAndCompare(std::move(module), fake_argument_ptrs, error,
                        reference_preprocessor);
@@ -235,7 +242,7 @@
   std::vector<Literal*> fake_argument_ptrs;
   absl::c_transform(
       fake_arguments, std::back_inserter(fake_argument_ptrs),
-      [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+      [](const Literal& literal) { return const_cast<Literal*>(&literal); });
 
   return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error,
                                   reference_preprocessor);
@@ -269,7 +276,7 @@
   std::vector<Literal*> fake_argument_ptrs;
   absl::c_transform(
       fake_arguments, std::back_inserter(fake_argument_ptrs),
-      [](const std::unique_ptr<Literal>& literal) { return literal.get(); });
+      [](const Literal& literal) { return const_cast<Literal*>(&literal); });
   return test_runner_
                  .Execute(std::move(module_or_status.ValueOrDie()),
                           fake_argument_ptrs, /*run_hlo_passes=*/true)
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 4c88257..0ae4bdc 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -80,6 +80,8 @@
   static StatusOr<bool> RunHloPass(HloPassInterface* hlo_pass,
                                    HloModule* module);
 
+  static PrecisionConfig DefaultPrecisionConfig(int operands);
+
  protected:
   // This uses the interpreter backend as the reference backend and
   // automatically finds another supported backend as the test backend. If the
@@ -113,16 +115,16 @@
   }
 
   // Executes the given module and return the result as a Literal.
-  StatusOr<std::unique_ptr<Literal>> Execute(
-      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+  StatusOr<Literal> Execute(std::unique_ptr<HloModule> module,
+                            absl::Span<Literal* const> arguments);
 
   // Same as above, except the module will be executed without running any HLO
   // passes on it.
-  std::unique_ptr<Literal> ExecuteNoHloPasses(
-      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+  Literal ExecuteNoHloPasses(std::unique_ptr<HloModule> module,
+                             absl::Span<Literal* const> arguments);
 
-  std::unique_ptr<Literal> ExecuteAndTransfer(
-      std::unique_ptr<HloModule> module, absl::Span<Literal* const> arguments);
+  Literal ExecuteAndTransfer(std::unique_ptr<HloModule> module,
+                             absl::Span<Literal* const> arguments);
 
   // Executes the given hlo module on two backends and compares results.
   //
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index 96f7221..43cca91 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -155,20 +155,20 @@
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
                                                  const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR0<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Equal(
     absl::Span<const NativeT> expected, const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR1<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Equal(
     std::initializer_list<std::initializer_list<NativeT>> expected,
     const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR2<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
@@ -176,46 +176,46 @@
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
     const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR3<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
     const Array2D<NativeT>& expected, const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR2FromArray2D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
     const Array3D<NativeT>& expected, const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR3FromArray3D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
     const Array4D<NativeT>& expected, const LiteralSlice& actual) {
-  EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
+  EXPECT_TRUE(Equal(LiteralUtil::CreateR4FromArray4D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
                                                 const LiteralSlice& actual,
                                                 const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR0<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Near(
     absl::Span<const NativeT> expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR1<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Near(
     std::initializer_list<std::initializer_list<NativeT>> expected,
     const LiteralSlice& actual, const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR2<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
@@ -223,7 +223,7 @@
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
     const LiteralSlice& actual, const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR3<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
@@ -232,28 +232,28 @@
         std::initializer_list<std::initializer_list<NativeT>>>>
         expected,
     const LiteralSlice& actual, const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR4<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
     const Array2D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR2FromArray2D(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
     const Array3D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR3FromArray3D(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
     const Array4D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
+  EXPECT_TRUE(Near(LiteralUtil::CreateR4FromArray4D(expected), actual, error));
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index 4151bfa..b6f9b81 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,11 +31,11 @@
 namespace {
 
 TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
-  std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
-      LiteralUtil::CreateR0<int32>(42).get(),
-      LiteralUtil::CreateR0<int32>(64).get(),
+  Literal literal = LiteralUtil::MakeTupleFromSlices({
+      LiteralUtil::CreateR0<int32>(42),
+      LiteralUtil::CreateR0<int32>(64),
   });
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, literal));
 }
 
 TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -43,15 +43,15 @@
   // un-fail an assertion failure. The CHECK-failure is death, so we can make a
   // death assertion.
   auto unequal_things_are_equal = [] {
-    std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
-        LiteralUtil::CreateR0<int32>(42).get(),
-        LiteralUtil::CreateR0<int32>(64).get(),
+    Literal lhs = LiteralUtil::MakeTupleFromSlices({
+        LiteralUtil::CreateR0<int32>(42),
+        LiteralUtil::CreateR0<int32>(64),
     });
-    std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
-        LiteralUtil::CreateR0<int32>(64).get(),
-        LiteralUtil::CreateR0<int32>(42).get(),
+    Literal rhs = LiteralUtil::MakeTupleFromSlices({
+        LiteralUtil::CreateR0<int32>(64),
+        LiteralUtil::CreateR0<int32>(42),
     });
-    CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+    CHECK(LiteralTestUtil::Equal(lhs, rhs)) << "LHS and RHS are unequal";
   };
   ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
 }
@@ -61,7 +61,7 @@
     auto two = LiteralUtil::CreateR0<float>(2);
     auto four = LiteralUtil::CreateR0<float>(4);
     ErrorSpec error(0.001);
-    CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+    CHECK(LiteralTestUtil::Near(two, four, error)) << "two is not near four";
   };
 
   tensorflow::Env* env = tensorflow::Env::Default();
@@ -86,14 +86,14 @@
     LiteralProto literal_proto;
     TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
                                             &literal_proto));
-    std::unique_ptr<Literal> literal =
+    Literal literal =
         Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
     if (result.find("expected") != string::npos) {
-      EXPECT_EQ("2", literal->ToString());
+      EXPECT_EQ("2", literal.ToString());
     } else if (result.find("actual") != string::npos) {
-      EXPECT_EQ("4", literal->ToString());
+      EXPECT_EQ("4", literal.ToString());
     } else if (result.find("mismatches") != string::npos) {
-      EXPECT_EQ("true", literal->ToString());
+      EXPECT_EQ("true", literal.ToString());
     } else {
       FAIL() << "unknown file in temporary directory: " << result;
     }
@@ -103,8 +103,7 @@
 TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
   auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
   auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
-  ::testing::AssertionResult result =
-      LiteralTestUtil::Equal(*expected, *actual);
+  ::testing::AssertionResult result = LiteralTestUtil::Equal(expected, actual);
   EXPECT_THAT(result.message(),
               ::testing::HasSubstr("Expected literal:\n{1, 2, 3}"));
   EXPECT_THAT(result.message(),
@@ -116,7 +115,7 @@
       {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
   auto b = LiteralUtil::CreateR1<float>(
       {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
-  EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
 }
 
 TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
@@ -124,7 +123,7 @@
       {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
   auto b = LiteralUtil::CreateR1<float>(
       {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
-  EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
+  EXPECT_TRUE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
 }
 
 TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
@@ -132,8 +131,8 @@
       {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
   auto b =
       LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
-  EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
-  EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
+  EXPECT_FALSE(LiteralTestUtil::Near(a, b, ErrorSpec{0.0001}));
+  EXPECT_FALSE(LiteralTestUtil::Near(b, a, ErrorSpec{0.0001}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 237a4a3..dbdd20d 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -45,7 +45,7 @@
   TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
 
   auto x_array =
-      LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
 
   int64 allocation_count_before = allocator_->allocation_count();
 
@@ -58,7 +58,7 @@
                           DefaultExecutableBuildOptions(), options);
 
   LiteralTestUtil::ExpectR1Near<float>(
-      {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_);
+      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(*result), error_spec_);
 
   // At least one allocation should have been performed when executing the
   // computation.
@@ -92,7 +92,7 @@
         computation, {}, ExecutableBuildOptions().set_device_ordinal(d),
         ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator));
     LiteralTestUtil::ExpectR1Near<float>(
-        {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+        {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
 
     // At least one allocation should have been performed when executing the
     // computation.
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 1a823cf..a99b43f 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -58,7 +58,7 @@
 
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
-  LiteralTestUtil::ExpectR0Near<float>(123.f, *ShapedBufferToLiteral(result),
+  LiteralTestUtil::ExpectR0Near<float>(123.f, ShapedBufferToLiteral(result),
                                        error_spec_);
 }
 
@@ -68,10 +68,10 @@
   auto y = ConstantR0<float>(&builder, 123.0f);
   Add(x, y);
 
-  auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
+  auto x_value = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
-  LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
+  LiteralTestUtil::ExpectR0Near<float>(165.f, ShapedBufferToLiteral(result),
                                        error_spec_);
 }
 
@@ -81,10 +81,10 @@
   auto y = ConstantR1<float>(&builder, {});
   Add(x, y);
 
-  auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
+  auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
-  LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
+  LiteralTestUtil::ExpectR1Near<float>({}, ShapedBufferToLiteral(result),
                                        error_spec_);
 }
 
@@ -95,11 +95,11 @@
   Add(x, y);
 
   auto x_array =
-      LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
   LiteralTestUtil::ExpectR1Near<float>(
-      {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
 }
 
 XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
@@ -109,14 +109,14 @@
   Add(x, y);
 
   auto x_array =
-      LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
   ExecutionProfile profile;
   ScopedShapedBuffer result = ExecuteLocallyOrDie(
       builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
       DefaultExecutableRunOptions().set_execution_profile(&profile));
 
   LiteralTestUtil::ExpectR1Near<float>(
-      {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
   EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
 }
 
@@ -128,13 +128,13 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   // Create x as a col-major array.
-  auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+  auto x_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
       {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
   EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
                                 LayoutUtil::MakeLayout({0, 1})));
 
   // Create y as a row-major array.
-  auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
+  auto y_array = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
       {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
   EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
                                 LayoutUtil::MakeLayout({1, 0})));
@@ -142,15 +142,15 @@
   ScopedShapedBuffer result_colmaj =
       ExecuteLocallyOrDie(computation, {&x_array, &y_array});
   LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
-                                       *ShapedBufferToLiteral(result_colmaj),
+                                       ShapedBufferToLiteral(result_colmaj),
                                        error_spec_);
 
   // Run with the parameter values in a different order.
   ScopedShapedBuffer result_param_swap =
       ExecuteLocallyOrDie(computation, {&y_array, &x_array});
-  LiteralTestUtil::ExpectR2Near<float>(
-      {{11.0f, 22.0f}, {33.0f, 44.0f}},
-      *ShapedBufferToLiteral(result_param_swap), error_spec_);
+  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
+                                       ShapedBufferToLiteral(result_param_swap),
+                                       error_spec_);
 }
 
 XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
@@ -161,9 +161,9 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   auto x_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
   auto y_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
 
   // Run with col-major result layout.
   ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -174,7 +174,7 @@
   EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(),
                                 LayoutUtil::MakeLayout({0, 1})));
   LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
-                                       *ShapedBufferToLiteral(result_colmaj),
+                                       ShapedBufferToLiteral(result_colmaj),
                                        error_spec_);
 
   // Run with row-major result layout.
@@ -186,7 +186,7 @@
   EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(),
                                 LayoutUtil::MakeLayout({1, 0})));
   LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
-                                       *ShapedBufferToLiteral(result_rowmaj),
+                                       ShapedBufferToLiteral(result_rowmaj),
                                        error_spec_);
 }
 
@@ -198,9 +198,9 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   auto x_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
   auto y_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
 
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -208,13 +208,13 @@
   EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
   EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
 
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {0}));
+                                        LiteralSlice(result_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                        LiteralSlice(*result_literal, {1}));
+                                        LiteralSlice(result_literal, {1}));
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {2}));
+                                        LiteralSlice(result_literal, {2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
@@ -226,9 +226,9 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   auto x_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
   auto y_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
 
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -236,15 +236,15 @@
   EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
 
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {1}));
+                                        LiteralSlice(result_literal, {1}));
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {0, 0}));
+                                        LiteralSlice(result_literal, {0, 0}));
   LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                        LiteralSlice(*result_literal, {0, 1}));
+                                        LiteralSlice(result_literal, {0, 1}));
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {0, 2}));
+                                        LiteralSlice(result_literal, {0, 2}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
@@ -255,7 +255,7 @@
   Tuple(&builder, {x, y});
 
   auto array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
 
   ExecutableBuildOptions options = DefaultExecutableBuildOptions();
   Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -268,11 +268,11 @@
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array},
                           options, DefaultExecutableRunOptions());
 
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {0}));
+                                        LiteralSlice(result_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                        LiteralSlice(*result_literal, {1}));
+                                        LiteralSlice(result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -298,15 +298,15 @@
   Tuple(&builder, {array_sum, vector_diff});
   auto computation = builder.Build().ConsumeValueOrDie();
 
-  auto x_literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
-       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
-  auto y_literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
-       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+  auto x_literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
+  auto y_literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
+       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});
 
-  auto x_buffer = LiteralToShapedBuffer(*x_literal);
-  auto y_buffer = LiteralToShapedBuffer(*y_literal);
+  auto x_buffer = LiteralToShapedBuffer(x_literal);
+  auto y_buffer = LiteralToShapedBuffer(y_literal);
 
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});
@@ -314,11 +314,11 @@
   EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape()));
   EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
 
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
-                                        LiteralSlice(*result_literal, {0}));
+                                        LiteralSlice(result_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
-                                        LiteralSlice(*result_literal, {1}));
+                                        LiteralSlice(result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -344,21 +344,20 @@
   Tuple(&builder, {negate_array, vector_sum});
   auto computation = builder.Build().ConsumeValueOrDie();
 
-  auto arg_literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
-            LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
-           .get(),
-       LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
-  auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::MakeTupleFromSlices(
+           {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+            LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
+       LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
+  auto arg_buffer = LiteralToShapedBuffer(arg_literal);
 
   ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
 
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
-                                        LiteralSlice(*result_literal, {0}));
+                                        LiteralSlice(result_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
-                                        LiteralSlice(*result_literal, {1}));
+                                        LiteralSlice(result_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -377,24 +376,24 @@
   Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
   auto computation = builder.Build().ConsumeValueOrDie();
 
-  auto arg_literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
-       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
-  auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+  auto arg_literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
+  auto arg_buffer = LiteralToShapedBuffer(arg_literal);
 
   ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
-  std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
+  Literal result_0_literal = ShapedBufferToLiteral(result_0);
   LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
-                                        LiteralSlice(*result_0_literal, {0}));
+                                        LiteralSlice(result_0_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
-                                        LiteralSlice(*result_0_literal, {1}));
+                                        LiteralSlice(result_0_literal, {1}));
 
   ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
-  std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
+  Literal result_1_literal = ShapedBufferToLiteral(result_1);
   LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                        LiteralSlice(*result_1_literal, {0}));
+                                        LiteralSlice(result_1_literal, {0}));
   LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
-                                        LiteralSlice(*result_1_literal, {1}));
+                                        LiteralSlice(result_1_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -427,20 +426,19 @@
 
   // Feed in a tuple where each two-element vector element is {tuple_index,
   // -tuple_index}.
-  std::vector<std::unique_ptr<Literal>> arg_elements;
+  std::vector<Literal> arg_elements;
   for (int i = 0; i < kElementCount; ++i) {
     arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
   }
-  std::unique_ptr<Literal> arg_literal =
-      LiteralUtil::MakeTupleOwned(std::move(arg_elements));
-  auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_elements));
+  auto arg_buffer = LiteralToShapedBuffer(arg_literal);
 
   ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
 
   for (int i = 0; i < kElementCount; ++i) {
     LiteralTestUtil::ExpectR1Near<float>(
-        {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
+        {2.0f * i, 0.0f}, LiteralSlice(result_literal, {i}), error_spec_);
   }
 }
 
@@ -476,9 +474,9 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   // Construct the argument to pass to the computation.
-  std::vector<std::unique_ptr<Literal>> outer_tuple_elements;
+  std::vector<Literal> outer_tuple_elements;
   for (int i = 0; i < kFanout; ++i) {
-    std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
+    std::vector<Literal> inner_tuple_elements;
     for (int j = 0; j < kFanout; ++j) {
       inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
     }
@@ -487,16 +485,16 @@
   }
   auto arg_literal =
       LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
-  auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+  auto arg_buffer = LiteralToShapedBuffer(arg_literal);
 
   ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
 
   for (int i = 0; i < kFanout; ++i) {
     for (int j = 0; j < kFanout; ++j) {
-      LiteralTestUtil::ExpectR0Near<float>(
-          i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}),
-          error_spec_);
+      LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
+                                           LiteralSlice(result_literal, {i, j}),
+                                           error_spec_);
     }
   }
 }
@@ -525,23 +523,23 @@
   auto computation = builder.Build().ConsumeValueOrDie();
 
   // Construct the argument to pass to the computation.
-  std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
+  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
   for (int i = 0; i < kTupleDepth; ++i) {
-    std::vector<std::unique_ptr<Literal>> arg_vector;
+    std::vector<Literal> arg_vector;
     arg_vector.push_back(std::move(arg_literal));
     arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
   }
-  auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
+  auto arg_buffer = LiteralToShapedBuffer(arg_literal);
 
   ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
-  std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
+  Literal result_literal = ShapedBufferToLiteral(result);
 
   ShapeIndex index;
   for (int i = 0; i < kTupleDepth; ++i) {
     index.push_back(0);
   }
   LiteralTestUtil::ExpectR0Equal<float>(165.0,
-                                        LiteralSlice(*result_literal, index));
+                                        LiteralSlice(result_literal, index));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
@@ -552,7 +550,7 @@
   Add(x, y);
 
   auto x_array =
-      LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
   auto execute_status =
       ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
 
@@ -568,7 +566,7 @@
   Neg(x);
 
   auto x_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
   auto execute_status =
       ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
 
@@ -585,7 +583,7 @@
   Neg(x);
 
   auto x_array = LiteralToShapedBuffer(
-      *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
   auto execute_status = ExecuteLocally(
       builder.Build().ValueOrDie(), {&x_array},
       DefaultExecutableBuildOptions().set_result_layout(
@@ -622,7 +620,7 @@
           DefaultExecutableRunOptions().set_device_ordinal(d));
       EXPECT_EQ(d, result.device_ordinal());
       LiteralTestUtil::ExpectR0Equal<float>(42.0f,
-                                            *ShapedBufferToLiteral(result));
+                                            ShapedBufferToLiteral(result));
     }
   }
 }
@@ -666,8 +664,7 @@
     // As a check to verify that the computation ran of the device associated
     // with the stream. This is a weak check, but stronger verification is hard.
     EXPECT_EQ(d, result.device_ordinal());
-    LiteralTestUtil::ExpectR0Equal<float>(42.0f,
-                                          *ShapedBufferToLiteral(result));
+    LiteralTestUtil::ExpectR0Equal<float>(42.0f, ShapedBufferToLiteral(result));
   }
 }
 
@@ -745,11 +742,11 @@
 
   ScopedShapedBuffer result =
       ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
-  std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
+  Literal tuple_literal = ShapedBufferToLiteral(result);
   LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
-                                        LiteralSlice(*tuple_literal, {0}));
+                                        LiteralSlice(tuple_literal, {0}));
   LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
-                                        LiteralSlice(*tuple_literal, {1}));
+                                        LiteralSlice(tuple_literal, {1}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
@@ -768,7 +765,7 @@
       executable_status.ConsumeValueOrDie();
 
   auto x_array =
-      LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
   ScopedShapedBuffer result =
       executable->Run({&x_array}, DefaultExecutableRunOptions())
           .ConsumeValueOrDie();
@@ -778,7 +775,7 @@
                    ->BlockHostUntilDone());
 
   LiteralTestUtil::ExpectR1Near<float>(
-      {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_);
+      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
 }
 
 XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
@@ -792,33 +789,33 @@
     TF_ASSERT_OK_AND_ASSIGN(
         auto transferred_literal,
         local_client_->ShapedBufferToLiteral(shaped_buffer));
-    EXPECT_EQ(literal, *transferred_literal);
+    EXPECT_EQ(literal, transferred_literal);
   };
 
   // Array shapes.
-  test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
-  test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
-  test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
+  test_to_device_and_back(LiteralUtil::CreateR0<float>(42.0));
+  test_to_device_and_back(LiteralUtil::CreateR0<bool>(true));
+  test_to_device_and_back(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
   test_to_device_and_back(
-      *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
-  test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
+      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+  test_to_device_and_back(LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
 
   // Null shape (empty tuple).
-  test_to_device_and_back(*LiteralUtil::MakeTuple({}));
+  test_to_device_and_back(LiteralUtil::MakeTuple({}));
 
   // Non-nested tuples.
-  test_to_device_and_back(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
-  test_to_device_and_back(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
-                               LiteralUtil::CreateR0<float>(123456.0).get()}));
+  test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(12223.0)}));
+  test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+       LiteralUtil::CreateR0<float>(123456.0)}));
 
   // Nested tuple.
-  test_to_device_and_back(*LiteralUtil::MakeTuple(
-      {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
-                               LiteralUtil::CreateR0<float>(123456.0).get()})
-           .get(),
-       LiteralUtil::CreateR0<bool>(false).get()}));
+  test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::MakeTupleFromSlices(
+           {LiteralUtil::CreateR1<float>({1.0, -42.0}),
+            LiteralUtil::CreateR0<float>(123456.0)}),
+       LiteralUtil::CreateR0<bool>(false)}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -832,17 +829,17 @@
     TF_ASSERT_OK_AND_ASSIGN(
         auto transferred_literal,
         local_client_->ShapedBufferToLiteral(shaped_buffer));
-    EXPECT_EQ(literal, *transferred_literal);
+    EXPECT_EQ(literal, transferred_literal);
   };
 
   test_to_device_and_back(
-      *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
-  test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
+      LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+  test_to_device_and_back(LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
   test_to_device_and_back(
-      *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
-  test_to_device_and_back(*LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
-       LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+      LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+  test_to_device_and_back(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
+       LiteralUtil::CreateR0<int64>(123456789000LL)}));
 }
 
 XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
@@ -852,7 +849,7 @@
   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
   Add(in, constant);
 
-  std::unique_ptr<Literal> result;
+  Literal result;
   std::unique_ptr<tensorflow::Thread> thread(
       tensorflow::Env::Default()->StartThread(
           tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -861,13 +858,13 @@
           }));
 
   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
-      *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+      LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
       local_client_->default_device_ordinal()));
 
   // Join the thread.
   thread.reset();
 
-  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
 }
 
 XLA_TEST_F(LocalClientExecuteTest, InfeedOutfeedTest) {
@@ -884,14 +881,14 @@
           [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
 
   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
-      *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+      LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
       local_client_->default_device_ordinal()));
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+  TF_ASSERT_OK_AND_ASSIGN(Literal result,
                           local_client_->TransferFromOutfeedLocal(
                               shape, local_client_->default_device_ordinal()));
 
-  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
+  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
 }
 
 // Benchmark that measures the overhead of the LocalClient API when running a
@@ -922,8 +919,8 @@
   auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
   auto stream =
       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
-  ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
-                                                         buffer));
+  ASSERT_IS_OK(
+      transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
 
   const int kWarmups = 2;
 
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index a8c68fc..f90ef22 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -136,7 +136,7 @@
       .ConsumeValueOrDie();
 }
 
-std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+Literal LocalClientTestBase::ShapedBufferToLiteral(
     const ShapedBuffer& shaped_buffer) {
   return local_client_->ShapedBufferToLiteral(shaped_buffer)
       .ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 90095c5..4027c7b 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -86,8 +86,7 @@
 
   // Construct and return a literal containing the array represented by
   // shaped_buffer.
-  std::unique_ptr<Literal> ShapedBufferToLiteral(
-      const ShapedBuffer& shaped_buffer);
+  Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
 
   // Execute the given computation on the local client. With and without
   // options.
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 0732e19..4d327a6 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -169,11 +169,11 @@
 TEST_F(MapTest, MapEachElemPlusOneR0) {
   // Applies lambda (x) (+ x 1)) to an input scalar.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(42.0);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateAdderToOne(), {});
 
   ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
@@ -183,11 +183,11 @@
 XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+  Literal param0_literal = LiteralUtil::CreateR1<float>({});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateAdderToOne(), {0});
 
   ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
@@ -197,12 +197,12 @@
 TEST_F(MapTest, MapEachElemPlusOneR1S4) {
   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateAdderToOne(), {0});
 
   ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
@@ -211,12 +211,12 @@
 
 TEST_F(MapTest, MapEachF32ElementToS32Constant) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateScalarOne<int32>(), {0});
 
   ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -224,12 +224,12 @@
 
 TEST_F(MapTest, MapEachF32ElementToU32Constant) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
 
   ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
@@ -238,12 +238,12 @@
 TEST_F(MapTest, MapEachElemLongerChainR1) {
   // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
 
   ComputeAndCompareR1<float>(
@@ -255,11 +255,11 @@
   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
   // maps (lambda (x) (* x 2)) on the result.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+  Literal param0_literal = LiteralUtil::CreateR1<float>({});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
   Map(&builder, {map1}, CreateMulByTwo(), {0});
 
@@ -271,12 +271,12 @@
   // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
   // maps (lambda (x) (* x 2)) on the result.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
   Map(&builder, {map1}, CreateMulByTwo(), {0});
 
@@ -287,12 +287,12 @@
 TEST_F(MapTest, MapEachElemPlusOneR2) {
   // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+  Literal param0_literal = LiteralUtil::CreateR2<float>(
       {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param}, CreateAdderToOne(), {0, 1});
 
   Array2D<float> expected_array(
@@ -342,17 +342,17 @@
 TEST_F(MapTest, MapBinaryAdder) {
   // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
-  std::unique_ptr<Literal> param1_literal =
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+  Literal param1_literal =
       LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
       {0});
 
@@ -365,18 +365,18 @@
 // for Map that used to fail in shape inference (b/28989438).
 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
+  Literal param0_literal = LiteralUtil::CreateR2WithLayout(
       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
+  Literal param1_literal = LiteralUtil::CreateR2WithLayout(
       {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
       {0, 1});
 
@@ -391,18 +391,18 @@
 
 XLA_TEST_F(MapTest, AddR3_3x0x2) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal =
+  Literal param1_literal =
       LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
       {0, 1, 2});
 
@@ -413,22 +413,22 @@
 TEST_F(MapTest, MapTernaryAdder) {
   // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
-  std::unique_ptr<Literal> param1_literal =
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+  Literal param1_literal =
       LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
-  std::unique_ptr<Literal> param2_literal =
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
+  Literal param2_literal =
       LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
   std::unique_ptr<GlobalData> param2_data =
-      client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param2_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
-  auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
+  auto param2 = Parameter(&builder, 2, param2_literal.shape(), "param2");
   Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
 
   ComputeAndCompareR1<float>(
@@ -475,17 +475,17 @@
   Add(x, y);
   auto error_add = sub_builder->BuildAndNoteError();
 
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
-  std::unique_ptr<Literal> param1_literal =
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
+  Literal param1_literal =
       LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, error_add, {0});
 
   StatusOr<XlaComputation> computation_status = builder.Build();
@@ -513,15 +513,15 @@
   Pow(x, y);
   auto power = sub_builder->BuildAndNoteError();
 
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
-  std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+  Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, power, {});
 
   ComputeAndCompareR0<float>(&builder, 32.0f,
@@ -540,15 +540,15 @@
   Sub(y, x);  // note that this is y - x, not x - y
   auto sub_opposite = sub_builder->BuildAndNoteError();
 
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
-  std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+  Literal param1_literal = LiteralUtil::CreateR0<float>(5.0f);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param1_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, param1_literal.shape(), "param1");
   Map(&builder, {param0, param1}, sub_opposite, {});
 
   ComputeAndCompareR0<float>(
@@ -565,11 +565,11 @@
   Mul(x, x);
   auto square = sub_builder->BuildAndNoteError();
 
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(10.0f);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   Map(&builder, {param0}, square, {});
 
   ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index edb592f..3f27811 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -63,11 +63,11 @@
                                                  });
   Exp(data);
 
-  std::unique_ptr<Literal> expected =
+  Literal expected =
       LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f},    // row 0
                                            {0.36788f, 1.64872f}});  // row 1
 
-  this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+  this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
 }
 
 XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
@@ -92,10 +92,10 @@
                                                  });
   Map(&builder, {data}, add_half, {0, 1});
 
-  std::unique_ptr<Literal> expected =
+  Literal expected =
       LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f},     // row 0
                                            {-0.5f, 1.0f}});  // row 1
-  this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+  this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-5));
 }
 
 XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
@@ -111,10 +111,10 @@
                                                 });
   Max(lhs, rhs);
 
-  std::unique_ptr<Literal> expected =
+  Literal expected =
       LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f},     // row 0
                                            {3.0f, -4.0f}});  // row 1
-  this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+  this->ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6));
 }
 
 struct TestLinspaceMaxParam {
@@ -200,14 +200,12 @@
 
     TF_ASSERT_OK_AND_ASSIGN(
         auto lhs_handle,
-        client_->TransferToServer(
-            *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
-                lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+        client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+            lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
     TF_ASSERT_OK_AND_ASSIGN(
         auto rhs_handle,
-        client_->TransferToServer(
-            *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
-                rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+        client_->TransferToServer(LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+            rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
 
     XlaBuilder builder(TestName());
     auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 05f90ba..56aaeb0 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -47,7 +47,6 @@
 namespace xla {
 namespace {
 
-
 class MultiOutputFusionTest : public HloTestBase {
  protected:
   MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; }
@@ -90,8 +89,8 @@
     DotDimensionNumbers dot_dnums;
     dot_dnums.add_lhs_contracting_dimensions(1);
     dot_dnums.add_rhs_contracting_dimensions(0);
-    HloInstruction* dot = builder.AddInstruction(
-        HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums));
+    HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+        elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2)));
     auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
 
     if (manual_fusion) {
@@ -115,10 +114,10 @@
 
     Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size}));
     expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
+    Literal literal_r0 = LiteralUtil::CreateR0<float>(-9.0f);
     auto actual =
-        ExecuteAndTransfer(std::move(hlo_module),
-                           {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
-    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+        ExecuteAndTransfer(std::move(hlo_module), {&literal_r0, &arg1});
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
   }
 
   void RunTest1D(bool manual_fusion, int size) {
@@ -154,7 +153,7 @@
     dot_dnums.add_rhs_contracting_dimensions(0);
     HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
         ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape,
-        dot_dnums));
+        dot_dnums, DefaultPrecisionConfig(2)));
     auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
 
     if (manual_fusion) {
@@ -179,10 +178,9 @@
     Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}));
     input1.PopulateWithValue(1.);
 
-    Literal expect =
-        std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
+    Literal expect = LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f});
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
-    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, actual, error_spec_));
   }
 };
 
@@ -219,10 +217,9 @@
           LiteralUtil::CreateR0<float>(1.0)),
       LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
                                   LiteralUtil::CreateR0<int32>(4)));
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
+      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -248,9 +245,8 @@
       HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
           .ValueOrDie();
   auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
-  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, result);
 }
 
 XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -281,9 +277,8 @@
       HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
           .ValueOrDie();
   auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
-  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
+  LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, result);
 }
 
 const char* const kScalarOps = R"(
@@ -325,13 +320,12 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
           LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -357,13 +351,12 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
           LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -390,13 +383,12 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
-                                   LiteralUtil::CreateR1<float>({36, 64}),
-                                   LiteralUtil::CreateR1<float>({66, 138})),
-      *result));
+      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+                                  LiteralUtil::CreateR1<float>({36, 64}),
+                                  LiteralUtil::CreateR1<float>({66, 138})),
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -423,14 +415,13 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
           LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
           LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -457,15 +448,14 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
           LiteralUtil::CreateR3<float>(
               {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
           LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -493,16 +483,15 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR1<float>({14, 22}),
           LiteralUtil::CreateR3<float>(
               {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
           LiteralUtil::CreateR3<float>(
               {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -531,13 +520,13 @@
       LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
   auto init1 = LiteralUtil::CreateR0<float>(5);
   auto init2 = LiteralUtil::CreateR0<float>(6);
-  std::unique_ptr<Literal> result = ExecuteNoHloPasses(
-      std::move(module), {param.get(), init1.get(), init2.get()});
+  Literal result =
+      ExecuteNoHloPasses(std::move(module), {&param, &init1, &init2});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
           LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
-      *result));
+      result));
 }
 
 XLA_TEST_F(MultiOutputFusionTest,
@@ -566,10 +555,9 @@
   auto param = LiteralUtil::CreateR3<Eigen::half>(
       {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
        {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
-  std::unique_ptr<Literal> result =
-      ExecuteNoHloPasses(std::move(module), {param.get()});
+  Literal result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(
+      LiteralUtil::MakeTupleOwned(
           LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
           LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
           LiteralUtil::CreateR3<Eigen::half>(
@@ -577,7 +565,7 @@
                 {Eigen::half(3), Eigen::half(4)}},
                {{Eigen::half(5), Eigen::half(6)},
                 {Eigen::half(7), Eigen::half(8)}}})),
-      *result));
+      result));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
index 0a0426a..f246082 100644
--- a/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
+++ b/tensorflow/compiler/xla/tests/outfeed_in_nested_computation_test.cc
@@ -70,7 +70,7 @@
   GetTupleElement(result_tuple, 0);
   TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
 
-  std::unique_ptr<xla::Literal> comp_result;
+  Literal comp_result;
   std::unique_ptr<tensorflow::Thread> thread(
       tensorflow::Env::Default()->StartThread(
           tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -81,41 +81,41 @@
   VLOG(1) << "Transferring trip count to computation";
   // Transfer number of iterations to Infeed.
   TF_ASSERT_OK(
-      local_client_->TransferToInfeed(*LiteralUtil::CreateR0<int32_t>(1)));
+      local_client_->TransferToInfeed(LiteralUtil::CreateR0<int32_t>(1)));
 
   // Pick up value from outfeed
   {
     VLOG(1) << "Reading from condition outfeed";
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+    TF_ASSERT_OK_AND_ASSIGN(Literal r,
                             local_client_->TransferFromOutfeed(&int_shape));
-    EXPECT_EQ(r->Get<int32>({}), 1);
+    EXPECT_EQ(r.Get<int32>({}), 1);
   }
 
   VLOG(1) << "Writing data to infeed";
   // Transfer some stuff to Infeed for use inside of loop.
   TF_ASSERT_OK(local_client_->TransferToInfeed(
-      *LiteralUtil::CreateR1<int32_t>({10, 20})));
+      LiteralUtil::CreateR1<int32_t>({10, 20})));
 
   // Pick up value from outfeed
   {
     VLOG(1) << "Reading from body outfeed";
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+    TF_ASSERT_OK_AND_ASSIGN(Literal r,
                             local_client_->TransferFromOutfeed(&xfeed_shape));
-    EXPECT_EQ(r->Get<int32>({0}), 11);
-    EXPECT_EQ(r->Get<int32>({1}), 21);
+    EXPECT_EQ(r.Get<int32>({0}), 11);
+    EXPECT_EQ(r.Get<int32>({1}), 21);
   }
 
   {
     VLOG(1) << "Reading from condition outfeed";
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+    TF_ASSERT_OK_AND_ASSIGN(Literal r,
                             local_client_->TransferFromOutfeed(&int_shape));
-    EXPECT_EQ(r->Get<int32>({}), 0);
+    EXPECT_EQ(r.Get<int32>({}), 0);
   }
 
   // Joins the thread
   thread.reset();
 
-  EXPECT_EQ(comp_result->Get<int32>({}), 0);
+  EXPECT_EQ(comp_result.Get<int32>({}), 0);
 }
 
 XLA_TEST_F(OutfeedInNestedComputationTest, OutfeedInConditional) {
@@ -145,7 +145,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
 
-  std::unique_ptr<xla::Literal> comp_result;
+  Literal comp_result;
   std::unique_ptr<tensorflow::Thread> thread(
       tensorflow::Env::Default()->StartThread(
           tensorflow::ThreadOptions(), "execute_thread", [&] {
@@ -154,12 +154,12 @@
           }));
 
   TF_ASSERT_OK(
-      local_client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+      local_client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> r,
+  TF_ASSERT_OK_AND_ASSIGN(Literal r,
                           local_client_->TransferFromOutfeed(&result_shape));
 
-  EXPECT_EQ(r->Get<bool>({}), true);
+  EXPECT_EQ(r.Get<bool>({}), true);
 
   // Join the thread
   thread.reset();
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index cbeddff..6e98167 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@
   dimension->set_edge_padding_high(0);
   dimension->set_interior_padding(0);
 
-  Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
-      AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+  Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+      AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
   ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
 }
 
@@ -108,8 +108,8 @@
   dimension->set_edge_padding_high(4);
   dimension->set_interior_padding(7);
 
-  Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
-      AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+  Pad(AddParam(LiteralUtil::CreateR1<float>({}), &b),
+      AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
   ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
                              DefaultErrorSpec());
 }
@@ -123,8 +123,8 @@
   dimension->set_edge_padding_high(0);
   dimension->set_interior_padding(1);
 
-  Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
-      AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
+  Pad(AddParam(LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+      AddParam(LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
   std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
   ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
 }
@@ -132,7 +132,7 @@
 XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
   XlaBuilder b(TestName());
   Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
-      AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+      AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
       r4_padding_on_dim0_dim1_);
   ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
                              DefaultErrorSpec());
@@ -148,7 +148,7 @@
   });
   input->FillWithYX(input_xy);
 
-  Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+  Pad(AddParam(*input, &b), AddParam(LiteralUtil::CreateR0<float>(1.5), &b),
       r4_padding_on_dim0_dim1_);
 
   auto expected = absl::make_unique<Array4D<float>>(2, 3, 3, 2);
@@ -168,7 +168,7 @@
   const float pad_value = 1.5f;
   Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
   Pad(AddParam(input, &b),
-      AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+      AddParam(LiteralUtil::CreateR0<float>(pad_value), &b),
       r4_padding_on_dim0_dim1_);
 
   auto expected = absl::make_unique<Array4D<float>>(8, 5, 1, 1);
@@ -208,10 +208,10 @@
   const float pad_value = -5.123f;
   Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
   auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
-  input = input->Relayout(layout);
+  input = input.Relayout(layout);
 
-  Pad(AddParam(*input, &b),
-      AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+  Pad(AddParam(input, &b),
+      AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
 
   Array4D<float> expected_array(1, 1, 5, 8);
   expected_array.Fill(pad_value);
@@ -254,10 +254,10 @@
   input_array(0, 24, 6, 6) = 2.0f;
   input_array(0, 17, 2, 5) = 3.0f;
   auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
-  input = input->Relayout(layout);
+  input = input.Relayout(layout);
 
-  Pad(AddParam(*input, &b),
-      AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
+  Pad(AddParam(input, &b),
+      AddParam(LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
 
   Array4D<float> expected_array(1, 25, 17, 11);
   expected_array.Fill(pad_value);
@@ -331,7 +331,7 @@
     padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
                                                                   100 * dim);
   }
-  Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
+  Pad(input, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
 
   auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
   ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -353,8 +353,7 @@
   padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
   padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
   padding_config.mutable_dimensions(1)->set_interior_padding(2);
-  Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
-      padding_config);
+  Pad(input, AddParam(LiteralUtil::CreateR0<float>(3.14f), &b), padding_config);
 
   auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
   ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -379,7 +378,7 @@
     padding_config.mutable_dimensions(dim)->set_interior_padding(
         interior_padding);
   }
-  Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+  Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
       padding_config);
 
   auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -407,7 +406,7 @@
     padding_config.mutable_dimensions(dim)->set_interior_padding(
         interior_padding);
   }
-  Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+  Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
       padding_config);
 
   auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -435,7 +434,7 @@
     padding_config.mutable_dimensions(dim)->set_interior_padding(
         interior_padding[dim]);
   }
-  Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+  Pad(input, AddParam(LiteralUtil::CreateR0<float>(2.718f), &b),
       padding_config);
 
   auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -452,13 +451,12 @@
 
   XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
   auto reduce =
-      Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
+      Reduce(input, AddParam(LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
 
   PaddingConfig padding_config = MakeNoPaddingConfig(3);
   padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
   padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
-  Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
-      padding_config);
+  Pad(reduce, AddParam(LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
 
   Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
                            {{2.0, 2.0}, {2.0, 2.0}},
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index f6c762e..dcb4c11 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -42,10 +42,9 @@
 
 XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
-      LiteralUtil::CreateR0<float>(3.14159f);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(3.14159f);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
 
@@ -55,9 +54,9 @@
 
 XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+  Literal param0_literal = LiteralUtil::CreateR1<float>({});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
 
@@ -67,10 +66,9 @@
 
 XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
-      LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+  Literal param0_literal = LiteralUtil::CreateR1<float>({3.14f, -100.25f});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
 
@@ -81,9 +79,9 @@
 XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
   XlaBuilder builder(TestName());
   string str("hello world");
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+  Literal param0_literal = LiteralUtil::CreateR1U8(str);
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0,
             ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
@@ -94,10 +92,10 @@
 
 XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
 
@@ -107,10 +105,10 @@
 
 XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+  Literal param0_literal = LiteralUtil::CreateR2<float>(
       {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
 
   Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
 
@@ -123,15 +121,15 @@
 XLA_TEST_F(ParamsTest, TwoParameters) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+  Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+  Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
 
   // Use both parameters
   //
@@ -154,9 +152,9 @@
 XLA_TEST_F(ParamsTest, MissingParameter) {
   // Test that an error is returned when a computation with an incomplete set of
   // parameters (parameter numbers not contiguous from 0) is executed.
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+  Literal literal = LiteralUtil::CreateR0<float>(3.14159f);
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
+      client_->TransferToServer(literal).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
   Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
@@ -168,15 +166,15 @@
 XLA_TEST_F(ParamsTest, UnusedParameter) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+  Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
-  Parameter(&builder, 0, literal0->shape(), "param0");
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  Parameter(&builder, 0, literal0.shape(), "param0");
 
-  std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+  Literal literal1 = LiteralUtil::CreateR1<float>({10, 20});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
-  Parameter(&builder, 1, literal1->shape(), "param1");
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+  Parameter(&builder, 1, literal1.shape(), "param1");
 
   ComputeAndCompareR1<float>(&builder, {10, 20},
                              {param0_data.get(), param1_data.get()},
@@ -188,18 +186,17 @@
   // unused expression.
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+  Literal literal0 = LiteralUtil::CreateR1<float>({1, 2});
   std::unique_ptr<GlobalData> param0_data =
-      client_->TransferToServer(*literal0).ConsumeValueOrDie();
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> literal1 =
-      LiteralUtil::CreateR1<float>({10, 20, 30});
+  Literal literal1 = LiteralUtil::CreateR1<float>({10, 20, 30});
   std::unique_ptr<GlobalData> param1_data =
-      client_->TransferToServer(*literal1).ConsumeValueOrDie();
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
 
-  auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
-  auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
-  auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
+  auto param0 = Parameter(&builder, 0, literal0.shape(), "param0");
+  auto param1 = Parameter(&builder, 1, literal1.shape(), "param1");
+  auto param2 = Parameter(&builder, 2, literal1.shape(), "param2");
 
   // This add is unused.
   Add(param1, param2);
@@ -233,10 +230,10 @@
 
     std::vector<float> sum_value = {{entry0, entry1}};
     sum_value.resize(size);
-    std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+    Literal literal = LiteralUtil::CreateR1<float>(sum_value);
     param_data_owner.push_back(
-        client_->TransferToServer(*literal).ConsumeValueOrDie());
-    XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+        client_->TransferToServer(literal).ConsumeValueOrDie());
+    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
     sum_handle = Add(sum_handle, param);
   }
 
@@ -268,10 +265,10 @@
   constexpr int kParamCount = 3000;
   for (int i = 0; i < kParamCount; ++i) {
     target += i;
-    std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
+    Literal literal = LiteralUtil::CreateR0<float>(i);
     param_data_owner.push_back(
-        std::move(client_->TransferToServer(*literal)).ValueOrDie());
-    XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+        std::move(client_->TransferToServer(literal)).ValueOrDie());
+    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
     sum_handle = Add(sum_handle, param);
   }
 
@@ -300,10 +297,10 @@
   std::vector<XlaOp> params;
   for (int i = 0; i < kParamCount; ++i) {
     target += i;
-    std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+    Literal literal = LiteralUtil::CreateR1<int32>({i, i});
     param_data_owner.push_back(
-        std::move(client_->TransferToServer(*literal)).ValueOrDie());
-    XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+        std::move(client_->TransferToServer(literal)).ValueOrDie());
+    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
     params.push_back(param);
     sum_handle = Add(sum_handle, param);
   }
@@ -321,13 +318,14 @@
     param_data.push_back(data.get());
   }
 
-  std::vector<std::unique_ptr<Literal>> elements;
+  std::vector<Literal> elements;
   std::vector<const Literal*> ptrs;
+  elements.reserve(kParamCount);
   for (int i = 0; i < kParamCount; ++i) {
     elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
-    ptrs.push_back(elements.back().get());
+    ptrs.push_back(&elements.back());
   }
-  ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+  ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
 }
 
 // Test large number of parameters flowing into a while-loop.
@@ -356,23 +354,23 @@
   std::vector<XlaOp> params;
   std::vector<Shape> parameter_shapes;
   for (int i = 0; i < kParamCount; ++i) {
-    std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
+    Literal literal = LiteralUtil::CreateR1<int32>({i, i});
     param_data_owner.push_back(
-        std::move(client_->TransferToServer(*literal)).ValueOrDie());
-    XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+        std::move(client_->TransferToServer(literal)).ValueOrDie());
+    XlaOp param = Parameter(&builder, i, literal.shape(), "param");
     params.push_back(param);
-    parameter_shapes.push_back(literal->shape());
+    parameter_shapes.push_back(literal.shape());
   }
 
   // Add bool parameter for the loop condition. Use a parameter HLO instead of a
   // constant because DCE may eliminate the while-body otherwise.
-  std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
+  Literal bool_literal = LiteralUtil::CreateR0<bool>(false);
   param_data_owner.push_back(
-      std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
+      std::move(client_->TransferToServer(bool_literal)).ValueOrDie());
   XlaOp bool_param =
-      Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
+      Parameter(&builder, kParamCount, bool_literal.shape(), "bool_param");
   params.push_back(bool_param);
-  parameter_shapes.push_back(bool_literal->shape());
+  parameter_shapes.push_back(bool_literal.shape());
 
   auto init = Tuple(&builder, params);
 
@@ -420,13 +418,14 @@
     param_data.push_back(data.get());
   }
 
-  std::vector<std::unique_ptr<Literal>> elements;
+  std::vector<Literal> elements;
   std::vector<const Literal*> ptrs;
+  elements.reserve(kParamCount);
   for (int i = 0; i < kParamCount; ++i) {
     elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
-    ptrs.push_back(elements.back().get());
+    ptrs.push_back(&elements.back());
   }
-  ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
+  ComputeAndCompareTuple(&builder, LiteralUtil::MakeTuple(ptrs), param_data);
 }
 
 #endif
@@ -443,9 +442,9 @@
 
   std::unique_ptr<GlobalData> data =
       client_
-          ->TransferToServer(*LiteralUtil::MakeTuple({
-              LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
-              LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+          ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+              LiteralUtil::CreateR1<float>({1, 2, 3}),
+              LiteralUtil::CreateR1<float>({4, 5, 6}),
           }))
           .ConsumeValueOrDie();
 
@@ -457,34 +456,34 @@
 // Verifies that passing a 2x2 with {0, 1} layout returns the same value back
 // when (transferred to the server and) passed through a parameter.
 XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+  Literal literal = LiteralUtil::CreateR2WithLayout<float>(
       {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
   XlaBuilder builder(TestName());
-  Parameter(&builder, 0, literal->shape(), "input");
+  Parameter(&builder, 0, literal.shape(), "input");
 
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
 }
 
 // As above, but for {1, 0} layout.
 XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+  Literal literal = LiteralUtil::CreateR2WithLayout<float>(
       {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
   XlaBuilder builder(TestName());
-  Parameter(&builder, 0, literal->shape(), "input");
+  Parameter(&builder, 0, literal.shape(), "input");
 
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
-  ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+      client_->TransferToServer(literal).ConsumeValueOrDie();
+  ComputeAndCompareLiteral(&builder, literal, {data.get()}, ErrorSpec(1e-3));
 }
 
 XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+  Literal literal = LiteralUtil::CreateR2<float>({
       {1, 3},
       {2, 4},
   });
-  const Shape original = literal->shape();
+  const Shape original = literal.shape();
   {
     // Reverse the layout present in original, and make that the layout of the
     // literal.
@@ -492,9 +491,9 @@
         original.layout().minor_to_major().begin(),
         original.layout().minor_to_major().end());
     std::reverse(original_layout.begin(), original_layout.end());
-    *literal->mutable_shape_do_not_use()->mutable_layout() =
+    *literal.mutable_shape_do_not_use()->mutable_layout() =
         LayoutUtil::MakeLayout(original_layout);
-    ASSERT_EQ(2, literal->Get<float>({0, 1}));
+    ASSERT_EQ(2, literal.Get<float>({0, 1}));
   }
   // Use the original shape in building the computation.
   XlaBuilder builder(TestName());
@@ -503,7 +502,7 @@
   Slice(input, {0, 1}, {1, 2}, {1, 1});
 
   std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*literal).ConsumeValueOrDie();
+      client_->TransferToServer(literal).ConsumeValueOrDie();
   // Check that we got the off-diagonal value that we expected.
   Array2D<float> expected(1, 1);
   expected(0, 0) = 2;
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 5f322b7..8f2c26f 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -37,8 +37,7 @@
 class PrngTest : public ClientLibraryTestBase {
  protected:
   template <typename T>
-  std::unique_ptr<Literal> UniformTest(T a, T b, absl::Span<const int64> dims,
-                                       int64 seed = 42);
+  Literal UniformTest(T a, T b, absl::Span<const int64> dims, int64 seed = 42);
 
   // Computes the χ² statistic of a sample of the discrete uniform distribution
   // of the given range size. `expected_count` is the number of times each
@@ -49,9 +48,8 @@
 };
 
 template <typename T>
-std::unique_ptr<Literal> PrngTest::UniformTest(T a, T b,
-                                               absl::Span<const int64> dims,
-                                               int64 seed) {
+Literal PrngTest::UniformTest(T a, T b, absl::Span<const int64> dims,
+                              int64 seed) {
   XlaBuilder builder(TestName());
   RngUniform(
       ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
@@ -60,8 +58,8 @@
   SetSeed(seed);
   auto actual =
       ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
-  EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
-  actual->EachCell<T>([=](absl::Span<const int64>, T value) {
+  EXPECT_THAT(dims, ::testing::ElementsAreArray(actual.shape().dimensions()));
+  actual.EachCell<T>([=](absl::Span<const int64>, T value) {
     EXPECT_LE(a, value);
     EXPECT_LT(value, b);
   });
@@ -116,11 +114,10 @@
   constexpr int64 count = 100;
   for (int64 seed = 0; seed < count; ++seed) {
     auto result = UniformTest<bfloat16>(low, high, {}, /*seed=*/seed);
-    result->Literal::EachCell<bfloat16>(
-        [&](absl::Span<const int64>, bfloat16 value) {
-          int64 index = static_cast<int64>((value - low) / interval);
-          counts[index]++;
-        });
+    result.EachCell<bfloat16>([&](absl::Span<const int64>, bfloat16 value) {
+      int64 index = static_cast<int64>((value - low) / interval);
+      counts[index]++;
+    });
   }
   // Each bucket should have similar amount of counts. That is, not more than
   // 10% of total counts. This mostly tests that we don't fall into a 1:2:2
@@ -149,7 +146,7 @@
   auto actual =
       ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
   std::vector<int32> counts(range_size, 0);
-  actual->EachCell<int32>(
+  actual.EachCell<int32>(
       [&counts](absl::Span<const int64>, int32 value) { ++counts[value]; });
   int64 sum = 0;
   for (int32 i = 0; i < range_size; ++i) {
@@ -192,12 +189,12 @@
   };
 
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
-                          client_->TransferToServer(*param0_literal));
+                          client_->TransferToServer(param0_literal));
 
-  auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+  auto param0 = Parameter(&builder, 0, param0_literal.shape(), "param0");
   auto fn = build_sum_rng(builder);
   Map(&builder, {param0}, fn, {0});
 
@@ -210,12 +207,11 @@
                        computation,
                        /*arguments=*/{param0_data.get()}, &execution_options));
 
-  EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()),
-            ShapeUtil::ElementsIn(param0_literal->shape()));
-  for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) {
-    EXPECT_GE(actual->data<float>()[i], param0_literal->data<float>()[i]);
-    EXPECT_LT(actual->data<float>()[i],
-              param0_literal->data<float>()[i] + 1.0f);
+  EXPECT_EQ(ShapeUtil::ElementsIn(actual.shape()),
+            ShapeUtil::ElementsIn(param0_literal.shape()));
+  for (int i = 0; i < ShapeUtil::ElementsIn(actual.shape()); ++i) {
+    EXPECT_GE(actual.data<float>()[i], param0_literal.data<float>()[i]);
+    EXPECT_LT(actual.data<float>()[i], param0_literal.data<float>()[i] + 1.0f);
   }
 }
 
@@ -238,15 +234,15 @@
   ExecutionOptions execution_options2 = execution_options_;
   execution_options2.set_seed(65);
 
-  std::unique_ptr<Literal> result1;
+  Literal result1;
   {
     TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
     TF_ASSERT_OK_AND_ASSIGN(
         result1, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
                                              &execution_options1));
   }
-  std::unique_ptr<Literal> result2;
-  std::unique_ptr<Literal> result3;
+  Literal result2;
+  Literal result3;
   {
     TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
     TF_ASSERT_OK_AND_ASSIGN(
@@ -257,9 +253,9 @@
                                              &execution_options1));
   }
 
-  std::unique_ptr<Literal> result4;
-  std::unique_ptr<Literal> result5;
-  std::unique_ptr<Literal> result6;
+  Literal result4;
+  Literal result5;
+  Literal result6;
   {
     TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
     TF_ASSERT_OK_AND_ASSIGN(
@@ -273,11 +269,11 @@
                                              &execution_options_));
   }
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
-  EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
-  EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
-  EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result1, result2));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result1, result3));
+  EXPECT_FALSE(LiteralTestUtil::Equal(result1, result4));
+  EXPECT_FALSE(LiteralTestUtil::Equal(result4, result5));
+  EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
 }
 
 XLA_TEST_F(PrngTest, TenValuesN01) {
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index 9af9ea4..c9096fb 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -92,7 +92,7 @@
   *reduce_input_shape->mutable_layout() =
       LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
 
-  std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+  Literal reduce_input = LiteralUtil::CreateR4<float>(
       {{ /*i0=0*/
         {/*i1=0*/
          {-0.246092796, -0.179497838, -0.161181688},
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 0916a07..26e2bfd 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -231,11 +231,10 @@
 
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal =
-      LiteralUtil::CreateR1<float>({input_values});
+  Literal a_literal = LiteralUtil::CreateR1<float>({input_values});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   ReducePrecision(a, exponent_bits, mantissa_bits);
 
@@ -255,10 +254,10 @@
            DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   // Abs doesn't affect resolution.
   auto abs = Abs(a);
@@ -284,10 +283,10 @@
            DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   // These two operations should be fused by any reasonable backend.
   auto abs = Abs(a);
@@ -310,10 +309,10 @@
            DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   // These two operations should be fused by any reasonable backend.
   auto abs = Abs(a);
@@ -334,10 +333,10 @@
            DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   // These two operations should be fused by any reasonable backend.
   auto abs = Abs(a);
@@ -359,10 +358,10 @@
            DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
+  Literal a_literal = LiteralUtil::CreateR1<float>({1.00001});
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  auto a = Parameter(&builder, 0, a_literal->shape(), "a");
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
+  auto a = Parameter(&builder, 0, a_literal.shape(), "a");
 
   // These two operations should be fused by any reasonable backend.
   auto abs = Abs(a);
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 8c62ade..83997cd 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -81,9 +81,9 @@
     }, 4);
     // clang-format on
     CHECK(ShapeUtil::Equal(
-        literal_3d_->shape(),
+        literal_3d_.shape(),
         ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
-        << literal_3d_->shape().ShortDebugString();
+        << literal_3d_.shape().ShortDebugString();
   }
 
   // Runs an R1 => R0 reduction test with the given number of elements.
@@ -102,10 +102,9 @@
         input_data[i] *= -1;
       }
     }
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR1(AsSlice(input_data));
+    Literal input_literal = LiteralUtil::CreateR1(AsSlice(input_data));
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     float expected = 0.0;
     for (float item : input_data) {
@@ -134,9 +133,9 @@
     Reduce(pred_values, init_value, reduce,
            /*dimensions_to_reduce=*/{0});
 
-    std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
+    Literal input_literal = LiteralUtil::CreateR1(input_data);
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     bool expected = and_reduce;
     for (bool item : input_data) {
@@ -175,12 +174,11 @@
 
     Array2D<uint8> input_data(rows, cols);
     input_data.FillRandom(0, 1);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR2FromArray2D(input_data);
+    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
     input_literal =
-        input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     std::array<bool, cols> expected;
     for (int64 colno = 0; colno < cols; ++colno) {
@@ -209,12 +207,11 @@
 
     Array2D<float> input_data(rows, cols);
     input_data.FillRandom(3.14f, 0.04);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR2FromArray2D(input_data);
+    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
     input_literal =
-        input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     float expected = 0.0;
     for (int64 rowno = 0; rowno < rows; ++rowno) {
@@ -237,12 +234,11 @@
 
     Array2D<float> input_data(rows, cols);
     input_data.FillRandom(3.14f, 0.04);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR2FromArray2D(input_data);
+    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
     input_literal =
-        input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     std::vector<float> expected;
     for (int64 colno = 0; colno < cols; ++colno) {
@@ -295,12 +291,11 @@
 
     Array2D<NativeT> input_data(rows, cols);
     input_data.FillUnique(initial_value);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR2FromArray2D(input_data);
+    Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
     input_literal =
-        input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
+        input_literal.Relayout(LayoutUtil::MakeLayout({minor, major}));
     std::unique_ptr<GlobalData> input_global_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
     // NativeT can be bool, and std::vector<bool> does not convert to
     // Span.
@@ -352,8 +347,8 @@
         reference_reduction_function_for_uints, unsigned_int_identity);
   }
 
-  std::unique_ptr<Literal> literal_2d_;
-  std::unique_ptr<Literal> literal_3d_;
+  Literal literal_2d_;
+  Literal literal_3d_;
   uint32 seed_ = 0xdeadbeef;
 };
 
@@ -450,11 +445,10 @@
 
   Array2D<float> input_data(rows, cols);
   input_data.FillRandom(3.14f, 0.04);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR2FromArray2D(input_data);
-  input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+  Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+  input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
   std::unique_ptr<GlobalData> input_global_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
   std::vector<float> expected;
   for (int64 colno = 0; colno < cols; ++colno) {
@@ -482,11 +476,10 @@
 
   Array2D<float> input_data(rows, cols);
   input_data.FillRandom(3.14f, 0.04);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR2FromArray2D(input_data);
-  input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
+  Literal input_literal = LiteralUtil::CreateR2FromArray2D(input_data);
+  input_literal = input_literal.Relayout(LayoutUtil::MakeLayout({0, 1}));
   std::unique_ptr<GlobalData> input_global_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
   std::vector<float> expected;
   for (int64 colno = 0; colno < cols; ++colno) {
@@ -511,10 +504,9 @@
   XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
   Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
-                          MakeFakeLiteral(input_shape));
+  TF_ASSERT_OK_AND_ASSIGN(Literal input_data, MakeFakeLiteral(input_shape));
 
-  ComputeAndCompare(&builder, {std::move(*input_data)}, ErrorSpec(0.01, 1e-4));
+  ComputeAndCompare(&builder, {std::move(input_data)}, ErrorSpec(0.01, 1e-4));
 }
 
 XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
@@ -531,10 +523,9 @@
 
   Array3D<float> input_data(rows, 2, cols / 2);
   input_data.FillRandom(3.14f, 0.04);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR3FromArray3D(input_data);
+  Literal input_literal = LiteralUtil::CreateR3FromArray3D(input_data);
   std::unique_ptr<GlobalData> input_global_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
   std::vector<float> expected;
   for (int64 major = 0; major < 2; ++major) {
@@ -595,7 +586,7 @@
   Array2D<float> input(300, 250);
   input.FillRandom(214.0f);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
-  Reduce(ConstantLiteral(&builder, *input_literal),
+  Reduce(ConstantLiteral(&builder, input_literal),
          ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
   auto input_max = FLT_MIN;
   input.Each(
@@ -610,7 +601,7 @@
   Array2D<float> input(150, 130);
   input.FillRandom(214.0f);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
-  Reduce(ConstantLiteral(&builder, *input_literal),
+  Reduce(ConstantLiteral(&builder, input_literal),
          ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
 
   auto input_min = FLT_MAX;
@@ -627,7 +618,7 @@
   auto initial_value =
       ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
 
-  Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
+  Reduce(ConstantLiteral(&builder, input_literal), initial_value, min, {0, 1});
   ComputeAndCompareR0<uint32>(&builder, 1, {});
 }
 
@@ -639,14 +630,14 @@
   auto initial_value =
       ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
 
-  Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
+  Reduce(ConstantLiteral(&builder, input_literal), initial_value, max, {0, 1});
   ComputeAndCompareR0<uint32>(&builder, 2, {});
 }
 
 // Reduces a matrix among dimension 1.
 XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_2d_);
+  auto m = ConstantLiteral(&builder, literal_2d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
 
@@ -657,7 +648,7 @@
 XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
   // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_2d_);
+  auto m = ConstantLiteral(&builder, literal_2d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
 
@@ -667,7 +658,7 @@
 // Tests 2D matrix ReduceToRow operation.
 XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
   XlaBuilder builder("reduce_among_y");
-  auto m = ConstantLiteral(&builder, *literal_2d_);
+  auto m = ConstantLiteral(&builder, literal_2d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
 
@@ -677,7 +668,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
 
@@ -687,7 +678,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
 
@@ -697,7 +688,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
 
@@ -707,7 +698,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
 
@@ -722,7 +713,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
 
@@ -739,7 +730,7 @@
 
 XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
   XlaBuilder builder(TestName());
-  auto m = ConstantLiteral(&builder, *literal_3d_);
+  auto m = ConstantLiteral(&builder, literal_3d_);
   auto add = CreateScalarAddComputation(F32, &builder);
   Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
 
@@ -824,12 +815,12 @@
 
   auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
   input_literal =
-      input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
+      input_literal.Relayout(LayoutUtil::MakeLayout(GetParam().layout));
   std::unique_ptr<GlobalData> input_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
 
   auto input_activations =
-      Parameter(&builder, 0, input_literal->shape(), "input");
+      Parameter(&builder, 0, input_literal.shape(), "input");
   XlaComputation add = CreateScalarAddComputation(F32, &builder);
   Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
          GetParam().reduce_dims);
@@ -866,21 +857,17 @@
                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
                       BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
 
-// TODO(b/64093391) Disabled on GPU due to an assertion failure when running
-// IrEmitterUnnested::EmitInitializer() for the Reduce operator.  Failed on
-// 2017-07-26.
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
+XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
   XlaBuilder builder(TestName());
   XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
 
   auto a = ConstantR0<float>(&builder, 2.0f);
   auto a2 = Abs(a);
 
-  std::unique_ptr<Literal> b_literal =
-      LiteralUtil::CreateR1<float>({1.0f, 4.0f});
+  Literal b_literal = LiteralUtil::CreateR1<float>({1.0f, 4.0f});
   std::unique_ptr<GlobalData> b_data =
-      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
-  auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+      client_->TransferToServer(b_literal).ConsumeValueOrDie();
+  auto b = Parameter(&builder, 0, b_literal.shape(), "b");
   Reduce(b, a2, max_f32, {0});
 
   ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
@@ -907,9 +894,9 @@
     std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
     auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
     auto input_data =
-        client_->TransferToServer(*input_literal).ConsumeValueOrDie();
-    Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
-           max_fn, {0});
+        client_->TransferToServer(input_literal).ConsumeValueOrDie();
+    Reduce(Parameter(&builder, 0, input_literal.shape(), "input"), init, max_fn,
+           {0});
 
     ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
   }
@@ -955,13 +942,12 @@
   float operand[] = {42.0f};
   float init = 58.5f;
   float expected = 42.0f;
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR1<float>(operand);
+  Literal input_literal = LiteralUtil::CreateR1<float>(operand);
   std::unique_ptr<GlobalData> input_global_data =
-      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
-  std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
+      client_->TransferToServer(input_literal).ConsumeValueOrDie();
+  Literal input_literal2 = LiteralUtil::CreateR0<float>(init);
   std::unique_ptr<GlobalData> input_global_data2 =
-      client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+      client_->TransferToServer(input_literal2).ConsumeValueOrDie();
   ComputeAndCompareR0<float>(
       &builder, expected, {input_global_data.get(), input_global_data2.get()},
       ErrorSpec(0.0001));
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 997880a..63491a9 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -73,7 +73,7 @@
                        absl::Span<const int64> window_dimensions,
                        absl::Span<const int64> window_strides,
                        Padding padding) {
-    auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+    auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
                                           &builder_);
     ReduceWindow(input, init,
                  CreateScalarAddComputation(FloatType(), &builder_),
@@ -107,9 +107,9 @@
 
 TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
+      LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
   const auto init_value =
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
+      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
   TF_ASSERT_OK(builder_.first_error());
   ReduceWindow(input, init_value,
                CreateScalarAddComputation(FloatType(), &builder_),
@@ -124,31 +124,31 @@
 // Regression test for b/68964348.
 TEST_P(ReduceWindowTest, R0ReduceWindow) {
   const auto input =
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
+      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
   const auto init =
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(1.0), &builder_);
   ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
                /*window_dimensions=*/{},
                /*window_strides=*/{}, Padding::kSame);
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(43.0), {},
                            ErrorSpec(0.00001));
 }
 
 TEST_P(ReduceWindowTest, Min3In5Stride2) {
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
   ReduceWindowMin(input, {3}, {2}, Padding::kValid);
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({100, 1}),
                            {}, ErrorSpec(0.00001));
 }
 
 TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
   ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
                   Padding::kSame);
   ComputeAndCompareLiteral(&builder_,
-                           *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+                           LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
                            {}, ErrorSpec(0.00001));
 }
 
@@ -161,7 +161,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
                                               {1, 1, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
                            DefaultErrorSpec());
 }
 
@@ -176,7 +176,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
                                               {1, 1, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
                            DefaultErrorSpec());
 }
 
@@ -190,7 +190,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
                                               {1, 2, 2, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
                            DefaultErrorSpec());
 }
 
@@ -207,7 +207,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res), {},
                            DefaultErrorSpec());
 }
 
@@ -229,8 +229,8 @@
       input_array, 0.0f, {win_len, win_len, 1, 1},
       {win_stride, win_stride, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
-                           {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+                           DefaultErrorSpec());
 }
 
 TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -252,8 +252,8 @@
       input_array, 0.0f, {win_len, win_len, 1, 1},
       {win_stride, win_stride, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
-                           {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+                           DefaultErrorSpec());
 }
 
 // Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -277,8 +277,8 @@
       input_array, 0.0f, {win_len, win_len, 1, 1},
       {win_stride, win_stride, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
-                           {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+                           DefaultErrorSpec());
 }
 
 TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -294,8 +294,8 @@
   auto result = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
-                           {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+                           DefaultErrorSpec());
 }
 
 // Tests a reduction function that is not a simple add/min/max/etc.
@@ -313,12 +313,12 @@
   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
   Min(Add(lhs, rhs),
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
+      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
   XlaComputation reduce_fn = b->BuildAndNoteError();
 
   ReduceWindow(
       input,
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
+      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
       reduce_fn,
       /*window_dimensions=*/{1, 1, 2, 1},
       /*window_strides=*/{1, 1, 1, 1}, padding);
@@ -332,19 +332,18 @@
                                            /*window=*/{1, 1, 2, 1},
                                            /*stride=*/{1, 1, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                            {}, DefaultErrorSpec());
 }
 
 TEST_P(ReduceWindowTest, R4UnitWindow) {
   Array4D<float> input_array(13, 12, 8, 15);
   input_array.FillRandom(2.f, 2.f);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
   XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "parameter", &builder_, &input);
+      0, input_literal, "parameter", &builder_, &input);
 
   Padding padding = Padding::kSame;
   ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding);
@@ -352,7 +351,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
                                               {1, 4, 1, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
                            {input_data.get()}, DefaultErrorSpec());
 }
 
@@ -360,9 +359,9 @@
   std::vector<int64> input_dims(6, 8);
   auto shape = ShapeUtil::MakeShape(F32, input_dims);
 
-  auto arg_literal = absl::make_unique<Literal>(shape);
-  arg_literal->PopulateWithValue(1.0f);
-  const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+  Literal arg_literal(shape);
+  arg_literal.PopulateWithValue(1.0f);
+  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
 
   Padding padding = Padding::kValid;
   ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
@@ -371,39 +370,38 @@
   std::vector<int64> output_dims = {6, 8, 6, 6, 8, 8};
   Shape result_shape =
       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
-  auto expected = absl::make_unique<Literal>(result_shape);
-  expected->PopulateWithValue(27.0f);
-  ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+  Literal expected(result_shape);
+  expected.PopulateWithValue(27.0f);
+  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
 }
 
 XLA_TEST_P(ReduceWindowTest, R6Add) {
   std::vector<int64> input_dims(6, 8);
   auto shape = ShapeUtil::MakeShape(F32, input_dims);
 
-  std::unique_ptr<Literal> arg_literal =
+  Literal arg_literal =
       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
 
-  const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
+  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);
 
   Padding padding = Padding::kValid;
   ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);
 
   std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
-  std::unique_ptr<Literal> expected =
+  Literal expected =
       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
 
-  ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
 }
 
 XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
   Array4D<float> input_array(2, 1, 27, 119);
   input_array.FillRandom(2.0f);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "parameter", &builder_, &input);
+      0, input_literal, "parameter", &builder_, &input);
 
   int win_len = 1;
   int stride = 8;
@@ -413,19 +411,18 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
                            {input_data.get()}, DefaultErrorSpec());
 }
 
 XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
   Array4D<float> input_array(3, 2, 4, 64);
   input_array.FillRandom(2.0f);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "parameter", &builder_, &input);
+      0, input_literal, "parameter", &builder_, &input);
 
   int win_len = 3;
   int stride = 1;
@@ -435,19 +432,18 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
                            {input_data.get()}, DefaultErrorSpec());
 }
 
 XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
   Array4D<float> input_array(1, 3, 12, 200);
   input_array.FillRandom(2.0f);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp input;
   auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "parameter", &builder_, &input);
+      0, input_literal, "parameter", &builder_, &input);
 
   int win_len = 8;
   int stride = 5;
@@ -457,7 +453,7 @@
   auto res = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
 
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*res),
                            {input_data.get()}, DefaultErrorSpec());
 }
 
@@ -478,18 +474,18 @@
   auto result = ReferenceUtil::ReduceWindow4DAdd(
       input_array, 0.0f, {win_len, win_len, 1, 1},
       {win_stride, win_stride, 1, 1}, padding);
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
-                           {}, DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
+                           DefaultErrorSpec());
 }
 
 XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
   std::vector<float> input_vector(128 * 9, 1);
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+      LiteralUtil::CreateR1<float>(input_vector), &builder_);
   ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
   ComputeAndCompareLiteral(
       &builder_,
-      *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+      LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
       DefaultErrorSpec());
 }
 
@@ -504,9 +500,9 @@
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+      LiteralUtil::CreateR1<float>(input_vector), &builder_);
   ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                            DefaultErrorSpec());
 }
 
@@ -521,9 +517,9 @@
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
       1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
   const auto input = CreateConstantFromLiteral(
-      *LiteralUtil::CreateR1<float>(input_vector), &builder_);
+      LiteralUtil::CreateR1<float>(input_vector), &builder_);
   ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
-  ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                            DefaultErrorSpec());
 }
 
@@ -540,9 +536,8 @@
   auto res = ReferenceUtil::ReduceWindow2DAdd(
       input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
 
-  ComputeAndCompareLiteral(&builder_,
-                           *LiteralUtil::CreateFromArray<float>(*res), {},
-                           DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+                           {}, DefaultErrorSpec());
 }
 
 TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
@@ -556,9 +551,8 @@
   auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
                                               padding);
 
-  ComputeAndCompareLiteral(&builder_,
-                           *LiteralUtil::CreateFromArray<float>(*res), {},
-                           DefaultErrorSpec());
+  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray<float>(*res),
+                           {}, DefaultErrorSpec());
 }
 
 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -594,7 +588,7 @@
   // Test names are not allowed to contain the '-' character.
   std::replace(str.begin(), str.end(), '-', 'n');
   if (::testing::get<1>(data.param)) {
-    str = absl::StrCat(str, "_bfloat16");
+    absl::StrAppend(&str, "_bfloat16");
   }
   return str;
 }
@@ -613,12 +607,11 @@
 
     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
                          param.base_bounds[2], param.base_bounds[3]);
-    input.FillIota(1);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR4FromArray4DWithLayout(
-            input, LayoutUtil::MakeLayout(param.layout));
+    input.FillRandom(0.1f, 0.1f);
+    Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+        input, LayoutUtil::MakeLayout(param.layout));
     XlaOp parameter;
-    auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+    auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
                                                        &b, &parameter);
 
     std::vector<std::pair<int64, int64>> padding(4);
@@ -627,9 +620,16 @@
     }
 
     auto init_value =
-        CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+        CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
     CHECK(param.reducer == kAdd || param.reducer == kMax);
-    auto computation = param.reducer == kAdd
+    auto reducer = param.reducer;
+    if (use_bfloat16() && Product(param.window_bounds) > 128) {
+      // To avoid numerical issues, force the reducer to be kMax for large bf16
+      // windows.
+      reducer = kMax;
+    }
+
+    auto computation = reducer == kAdd
                            ? CreateScalarAddComputation(FloatType(), &b)
                            : CreateScalarMaxComputation(FloatType(), &b);
     ReduceWindowWithGeneralPadding(
@@ -640,8 +640,8 @@
         /*window_strides=*/param.strides,
         /*padding=*/padding);
 
-    CHECK(param.reducer == kAdd || param.reducer == kMax);
-    auto reduce_func = param.reducer == kAdd
+    CHECK(reducer == kAdd || reducer == kMax);
+    auto reduce_func = reducer == kAdd
                            ? +[](float a, float b) { return a + b; }
                            : +[](float a, float b) { return std::max(a, b); };
     std::unique_ptr<Array4D<float>> expected =
@@ -652,12 +652,11 @@
             /*window=*/param.window_bounds,
             /*stride=*/param.strides,
             /*padding=*/padding);
-    std::unique_ptr<Literal> expected_literal =
-        LiteralUtil::CreateFromArray(*expected);
+    Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
-        input_literal->shape().element_type(),
-        AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
-    ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()},
+        input_literal.shape().element_type(),
+        AsInt64Slice(expected_literal.shape().dimensions()), param.layout);
+    ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
                              DefaultErrorSpec(), &expected_shape_with_layout);
   }
 };
@@ -809,6 +808,22 @@
                            /*pad_high=*/{1, 0, 0, 0},
                            /*layout=*/{3, 2, 1, 0},
                            /*reducer=*/kAdd},
+
+    R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3},
+                           /*window_bounds=*/{1, 64, 64, 1},
+                           /*strides=*/{1, 64, 64, 1},
+                           /*pad_low=*/{0, 0, 0, 0},
+                           /*pad_high=*/{0, 0, 0, 0},
+                           /*layout=*/{3, 0, 2, 1},
+                           /*reducer=*/kAdd},
+
+    R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
+                           /*window_bounds=*/{112, 112, 1, 8},
+                           /*strides=*/{112, 112, 1, 8},
+                           /*pad_low=*/{0, 0, 0, 0},
+                           /*pad_high=*/{0, 0, 0, 0},
+                           /*layout=*/{3, 2, 1, 0},
+                           /*reducer=*/kAdd},
 };
 
 INSTANTIATE_TEST_CASE_P(
@@ -930,6 +945,27 @@
     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
+     /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+    {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
+     /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
+     /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
 };
 
 string R3ReduceWindowTestDataToString(
@@ -944,7 +980,7 @@
       param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
       param.reducer == kAdd ? "add" : "max");
   if (::testing::get<1>(data.param)) {
-    str = absl::StrCat(str, "_bfloat16");
+    absl::StrAppend(&str, "_bfloat16");
   }
   return str;
 }
@@ -956,35 +992,41 @@
   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
 };
 
-TEST_P(R3ReduceWindowTest, Add) {
+TEST_P(R3ReduceWindowTest, DoIt) {
   XlaBuilder b(TestName());
   const auto& param = ::testing::get<0>(GetParam());
-  CHECK(param.reducer == kAdd);
 
   const float kInitValue = 0.0f;
   Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
-                       param.base_bounds[2], 1.0f);
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR3FromArray3DWithLayout(
-          input, LayoutUtil::MakeLayout(param.layout));
+                       param.base_bounds[2]);
+  input.FillRandom(0.1f, 0.1f);
+  Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
+      input, LayoutUtil::MakeLayout(param.layout));
+  auto reducer = param.reducer;
+  if (use_bfloat16()) {
+    input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);
+    if (Product(param.window_bounds) > 128) {
+      // To avoid numerical issues, force the reducer to be kMax for large bf16
+      // windows.
+      reducer = kMax;
+    }
+  }
 
-  XlaOp parameter;
-  auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
-                                                     &b, &parameter);
+  XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
   auto init_value =
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
+
+  auto computation = reducer == kAdd
+                         ? CreateScalarAddComputation(FloatType(), &b)
+                         : CreateScalarMaxComputation(FloatType(), &b);
+
   ReduceWindow(/*operand=*/parameter,
                /*init_value=*/init_value,
-               /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+               /*computation=*/computation,
                /*window_dimensions=*/param.window_bounds,
                /*window_strides=*/param.strides, /*padding=*/param.padding);
 
-  auto expected = ReferenceUtil::ReduceWindow3DAdd(
-      /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
-      /*stride=*/param.strides, /*padding=*/param.padding);
-
-  ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
-                           {input_arg.get()}, DefaultErrorSpec());
+  ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
 }
 
 INSTANTIATE_TEST_CASE_P(
@@ -1079,7 +1121,7 @@
       param.layout[1],  //
       "__reducer_", param.reducer == kAdd ? "add" : "max");
   if (::testing::get<1>(data.param)) {
-    str = absl::StrCat(str, "_bfloat16");
+    absl::StrAppend(&str, "_bfloat16");
   }
   return str;
 }
@@ -1093,16 +1135,14 @@
   void DoIt() {
     XlaBuilder b(TestName());
     const auto& param = ::testing::get<0>(GetParam());
-    CHECK(param.reducer == kAdd);
 
     const float kInitValue = 0.0f;
     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
-    std::unique_ptr<Literal> input_literal =
-        LiteralUtil::CreateR2FromArray2DWithLayout(
-            input, LayoutUtil::MakeLayout(param.layout));
+    Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
+        input, LayoutUtil::MakeLayout(param.layout));
 
     XlaOp parameter;
-    auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+    auto input_arg = CreateParameterAndTransferLiteral(0, input_literal, "p0",
                                                        &b, &parameter);
     std::vector<std::pair<int64, int64>> padding(2);
     for (int i = 0; i < 2; ++i) {
@@ -1112,7 +1152,7 @@
                            ? CreateScalarAddComputation(FloatType(), &b)
                            : CreateScalarMaxComputation(FloatType(), &b);
     auto init_value =
-        CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+        CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
     ReduceWindowWithGeneralPadding(
         /*operand=*/parameter,
         /*init_value=*/init_value,
@@ -1128,7 +1168,7 @@
         /*window=*/param.window_bounds,
         /*stride=*/param.strides, /*padding=*/padding);
 
-    ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
+    ComputeAndCompareLiteral(&b, LiteralUtil::CreateFromArray(*expected),
                              {input_arg.get()}, DefaultErrorSpec());
   }
 };
@@ -1282,7 +1322,7 @@
                    "__pad_high_", absl::StrJoin(param.pad_high, "x"),
                    "__reducer_", param.reducer == kAdd ? "add" : "max");
   if (::testing::get<1>(data.param)) {
-    str = absl::StrCat(str, "_bfloat16");
+    absl::StrAppend(&str, "_bfloat16");
   }
   return str;
 }
@@ -1302,11 +1342,11 @@
   const float kInitValue = 0.0f;
   std::vector<float> input_vector(param.base_bounds[0]);
   std::iota(std::begin(input_vector), std::end(input_vector), 0);
-  std::unique_ptr<Literal> input_literal =
+  Literal input_literal =
       LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
   XlaOp parameter;
-  auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
-                                                     &b, &parameter);
+  auto input_arg =
+      CreateParameterAndTransferLiteral(0, input_literal, "p0", &b, &parameter);
 
   std::vector<std::pair<int64, int64>> padding(1);
   padding[0] = {param.pad_low[0], param.pad_high[0]};
@@ -1315,7 +1355,7 @@
                          ? CreateScalarAddComputation(FloatType(), &b)
                          : CreateScalarMaxComputation(FloatType(), &b);
   auto init_value =
-      CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
   ReduceWindowWithGeneralPadding(
       /*operand=*/parameter,
       /*init_value=*/init_value,
@@ -1334,7 +1374,7 @@
       /*stride=*/param.strides,
       /*padding=*/padding);
 
-  ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
+  ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
                            {input_arg.get()}, DefaultErrorSpec());
 }
 
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index d891451..5cf87e5 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -58,13 +58,13 @@
   ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
 
   // Run it.
-  std::unique_ptr<Literal> literal =
+  Literal literal =
       client_
           ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
           .ConsumeValueOrDie();
 
   // Expect 4.
-  LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+  LiteralTestUtil::ExpectR0Equal<int32>(4, literal);
 }
 
 XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
@@ -91,12 +91,12 @@
 
   // Run it.
   std::unique_ptr<GlobalData> x_data =
-      client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+      client_->TransferToServer(LiteralUtil::CreateR0<int32>(2))
           .ConsumeValueOrDie();
   std::unique_ptr<GlobalData> y_data =
-      client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+      client_->TransferToServer(LiteralUtil::CreateR0<int32>(3))
           .ConsumeValueOrDie();
-  std::unique_ptr<Literal> literal =
+  Literal literal =
       client_
           ->ExecuteAndTransfer(replayed,
                                /*arguments=*/{x_data.get(), y_data.get()},
@@ -104,7 +104,7 @@
           .ConsumeValueOrDie();
 
   // Expect 5.
-  LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+  LiteralTestUtil::ExpectR0Equal<int32>(5, literal);
 }
 
 TEST_F(ReplayTest, MapPlusTwoOverR1) {
@@ -136,13 +136,13 @@
   ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
 
   // Run it.
-  std::unique_ptr<Literal> literal =
+  Literal literal =
       client_
           ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_)
           .ConsumeValueOrDie();
 
   // Expect result.
-  LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+  LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, literal);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 17d1271..dedc95b 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -57,12 +57,12 @@
   input_array.Fill(1.0f);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -70,12 +70,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -83,12 +83,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -99,29 +99,29 @@
   input_array.Fill(1.0f);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "parameter",
                                                  &builder, &parameter);
   auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
                          /*new_sizes=*/{});
   auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
 
   auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
 XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
+  Literal param0_literal = LiteralUtil::CreateR0<float>(1.0f);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+  auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
                                                  &builder, &parameter);
   auto a = Neg(parameter);
   Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
 
   auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -130,25 +130,25 @@
   Array2D<float> input_array(0, 3);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
   auto expected_literal = LiteralUtil::CreateR1<float>({});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
 XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
   XlaBuilder builder(TestName());
 
-  std::unique_ptr<Literal> param0_literal =
+  Literal param0_literal =
       LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
+  auto input = CreateParameterAndTransferLiteral(0, param0_literal, "param0",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
   auto expected_literal = LiteralUtil::CreateR1<float>({});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -157,11 +157,11 @@
   Array2D<float> input_array(3, 0);
   auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
   auto expected_literal = LiteralUtil::CreateR1<float>({});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -170,11 +170,11 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -183,11 +183,11 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
   auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -196,12 +196,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0},
           /*new_sizes=*/{2, 0});
   auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -211,13 +211,13 @@
   auto input_literal =
       LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0},
           /*new_sizes=*/{2, 3});
   auto expected_literal =
       LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -226,12 +226,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
           /*new_sizes=*/{2, 0});
   auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -241,14 +241,14 @@
   auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
   auto input_literal = LiteralUtil::CreateFromArray(*simple);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
           /*new_sizes=*/{3, 1});
 
   auto expected = ReferenceUtil::TransposeArray2D(*simple);
   auto expected_literal = LiteralUtil::CreateFromArray(*expected);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -258,14 +258,14 @@
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
           /*new_sizes=*/{3, 4});
 
   auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
   auto expected_literal = LiteralUtil::CreateFromArray(*expected);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -274,11 +274,11 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Transpose(parameter, {1, 0});
   auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -288,13 +288,13 @@
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Transpose(parameter, {1, 0});
 
   auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
   auto expected_literal = LiteralUtil::CreateFromArray(*expected);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -304,13 +304,13 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
           /*new_sizes=*/{2, 3, 0, 0});
   auto expected_literal =
       LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -318,12 +318,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
           /*new_sizes=*/{24, 0});
   auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -334,14 +334,14 @@
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
           /*new_sizes=*/{2, 6});
 
   auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
   auto expected_literal = LiteralUtil::CreateFromArray(*expected);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -349,12 +349,12 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
           /*new_sizes=*/{3, 0});
   auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -365,14 +365,14 @@
   auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
   auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
           /*new_sizes=*/{2, 6});
   Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
                            {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
   auto expected_literal = LiteralUtil::CreateFromArray(expected);
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -391,14 +391,14 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
           /*new_sizes=*/{24});
   auto expected_literal = LiteralUtil::CreateR1<float>(
       {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
        30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -406,7 +406,7 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
           /*new_sizes=*/{8, 3});
@@ -418,7 +418,7 @@
                                                         {35, 36, 37},
                                                         {40, 41, 42},
                                                         {45, 46, 47}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -426,14 +426,14 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
           /*new_sizes=*/{24});
   auto expected_literal = LiteralUtil::CreateR1<float>(
       {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
        15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -441,7 +441,7 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
           /*new_sizes=*/{8, 3});
@@ -453,7 +453,7 @@
                                                         {45, 16, 26},
                                                         {36, 46, 17},
                                                         {27, 37, 47}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -461,14 +461,14 @@
   XlaBuilder builder(TestName());
   auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
           /*new_sizes=*/{2, 6, 2});
   auto expected_literal = LiteralUtil::CreateR3<float>(
       {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
        {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -494,14 +494,14 @@
   t2x2x2x3.FillWithYX(*filler2x3);
   auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
   auto expected_literal = LiteralUtil::CreateR2<float>(
       {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
        {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
         6.0f}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -519,14 +519,14 @@
   t(1, 0, 1, 1) = 7;
   auto input_literal = LiteralUtil::CreateFromArray(t);
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
           /*new_sizes=*/{2, 4});
 
   auto expected_literal =
       LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -547,7 +547,7 @@
     Reshape(parameter, dimensions, {});
 
     auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
-    ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
+    ComputeAndCompareLiteral(&b, expected_literal, {input.get()},
                              zero_error_spec_);
   }
 }
@@ -556,7 +556,7 @@
   XlaBuilder b(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
                                                  &parameter);
   Reshape(parameter, {}, {});
   EXPECT_THAT(
@@ -568,7 +568,7 @@
   XlaBuilder b(TestName());
   auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", &b,
                                                  &parameter);
   Reshape(parameter, {1}, {});
   EXPECT_THAT(ExecuteToString(&b, {}),
@@ -604,7 +604,7 @@
        LayoutUtil::MakeLayout({0, 1, 2, 3}));
   // clang-format on
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
 
   Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
@@ -619,27 +619,26 @@
   *execution_options.mutable_shape_with_output_layout() =
       ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
                                      {1, 0});
-  std::unique_ptr<Literal> actual =
+  Literal actual =
       client_
           ->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
           .ConsumeValueOrDie();
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
   if (use_bfloat16()) {
-    expected = LiteralUtil::ConvertF32ToBF16(*expected);
+    expected = LiteralUtil::ConvertF32ToBF16(expected);
   }
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
 }
 
 XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+  Literal input_literal = LiteralUtil::CreateR2<float>({
       {0, 1, 2, 3, 4, 5, 6, 7},
       {100, 101, 102, 103, 104, 105, 106, 107},
       {200, 201, 202, 203, 204, 205, 206, 207},
   });
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
 
@@ -653,20 +652,20 @@
      {{204, 205, 206, 207}}}
   });
   // clang-format on
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
 // Tests R2->R4 reshape with the reshape dimensions {1, 0}.
 XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
+  Literal input_literal = LiteralUtil::CreateR2<float>({
       {0, 1, 2, 3, 4, 5, 6, 7},
       {100, 101, 102, 103, 104, 105, 106, 107},
       {200, 201, 202, 203, 204, 205, 206, 207},
   });
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
+  auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
                                                  &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
 
@@ -680,7 +679,7 @@
      {{206, 7, 107, 207}}}
   });
   // clang-format on
-  ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
+  ComputeAndCompareLiteral(&builder, expected_literal, {input.get()},
                            zero_error_spec_);
 }
 
@@ -691,17 +690,15 @@
   Array4D<float> input(2, 1, 1, 1);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+  Literal expected = LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, input_literal);
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
                            zero_error_spec_);
 }
 
@@ -712,17 +709,15 @@
   Array4D<float> input(2, 1, 4, 1);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+  Literal expected = LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, input_literal);
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
                            zero_error_spec_);
 }
 
@@ -734,12 +729,11 @@
   Array4D<float> input(5, 10, 2, 3);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
           /*new_sizes=*/{5, 60});
 
@@ -749,7 +743,7 @@
         *cell;
   });
   auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
                            zero_error_spec_);
 }
 
@@ -761,12 +755,11 @@
   input_array.Each(
       [&rng, &distribution](absl::Span<const int64> /* indices */,
                             float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
           /*new_sizes=*/{7, 2, 3, 5});
   XlaComputation computation = builder.Build().ConsumeValueOrDie();
@@ -775,7 +768,7 @@
   *execution_options.mutable_shape_with_output_layout() =
       ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
                                      {2, 3, 0, 1});
-  std::unique_ptr<Literal> output_literal =
+  Literal output_literal =
       client_
           ->ExecuteAndTransfer(computation, {input_data.get()},
                                &execution_options)
@@ -784,10 +777,10 @@
   // Since the reshape is a no-op, verify that it does not change the underlying
   // data.
   if (use_bfloat16()) {
-    auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
-    EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
+    auto expected = LiteralUtil::ConvertF32ToBF16(input_literal);
+    EXPECT_EQ(expected.data<bfloat16>(), output_literal.data<bfloat16>());
   } else {
-    EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
+    EXPECT_EQ(input_literal.data<float>(), output_literal.data<float>());
   }
 }
 
@@ -798,12 +791,12 @@
         {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
 
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+  auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
                                                  &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
           /*new_sizes=*/{1, 2, 3, 4});
 
-  ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
+  ComputeAndCompareLiteral(&builder, literal_1x2x3x4, {input.get()});
 }
 
 XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
@@ -813,7 +806,7 @@
 
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
+  auto input = CreateParameterAndTransferLiteral(0, literal_1x2x3x4, "input",
                                                  &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
           /*new_sizes=*/{2, 4, 3, 1});
@@ -830,7 +823,7 @@
         {{16}, {20}, {24}}}});
   // clang-format on
 
-  ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()});
+  ComputeAndCompareLiteral(&builder, expected_2x4x3x1, {input.get()});
 }
 
 XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
@@ -841,24 +834,23 @@
   Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
           /*new_sizes=*/new_bounds);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
-          ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal expected =
+      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+          .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
   // actually corresponds to a two minor transpose.
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
-                           zero_error_spec_, &expected->shape());
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+                           zero_error_spec_, &expected.shape());
 }
 
 XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
@@ -869,24 +861,23 @@
   Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
           /*new_sizes=*/new_bounds);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
-          ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal expected =
+      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+          .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
   // actually corresponds to a two minor transpose.
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
-                           zero_error_spec_, &expected->shape());
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+                           zero_error_spec_, &expected.shape());
 }
 
 XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
@@ -897,24 +888,23 @@
   Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
           /*new_sizes=*/new_bounds);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
-          ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal expected =
+      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+          .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
   // actually corresponds to a two minor transpose.
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
-                           zero_error_spec_, &expected->shape());
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+                           zero_error_spec_, &expected.shape());
 }
 
 XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
@@ -926,24 +916,23 @@
   Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
           /*new_sizes=*/new_bounds);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
-          ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
+  Literal expected =
+      LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, input_literal)
+          .Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
   // actually corresponds to a two minor transpose.
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
-                           zero_error_spec_, &expected->shape());
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+                           zero_error_spec_, &expected.shape());
 }
 
 XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
@@ -954,24 +943,23 @@
   Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
   input.Each([&rng, &distribution](absl::Span<const int64> /* indices */,
                                    float* cell) { *cell = distribution(rng); });
-  std::unique_ptr<Literal> input_literal =
-      LiteralUtil::CreateR4FromArray4DWithLayout(
-          input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+      input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
   XlaBuilder builder(TestName());
   XlaOp parameter;
-  auto input_data = CreateParameterAndTransferLiteral(
-      0, *input_literal, "input", &builder, &parameter);
+  auto input_data = CreateParameterAndTransferLiteral(0, input_literal, "input",
+                                                      &builder, &parameter);
   Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
           /*new_sizes=*/new_bounds);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
-          ->Relayout(input_literal->shape().layout());
+  Literal expected =
+      LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, input_literal)
+          .Relayout(input_literal.shape().layout());
 
   // Specify the requested output shape explicitly to ensure that this reshape
   // actually corresponds to a two minor transpose.
-  ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
-                           zero_error_spec_, &expected->shape());
+  ComputeAndCompareLiteral(&builder, expected, {input_data.get()},
+                           zero_error_spec_, &expected.shape());
 }
 
 #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 74ded82..4e55b0d 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -83,25 +83,25 @@
       ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
   std::iota(input_vector.begin(), input_vector.end(), 0.0);
   auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
-  auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
+  auto input_literal = r1_literal.Reshape(spec.input_dims).ConsumeValueOrDie();
 
   XlaBuilder builder(TestName());
-  auto a = AddParam(*input_literal, &builder);
+  auto a = AddParam(input_literal, &builder);
   Rev(a, spec.reversal);
 
-  std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
+  Literal expected = input_literal.Clone();
   std::vector<int64> output_indices(spec.input_dims.size());
-  expected->EachCell<float>([&](absl::Span<const int64> indices, float) {
+  expected.EachCell<float>([&](absl::Span<const int64> indices, float) {
     for (int64 i = 0; i < indices.size(); ++i) {
       output_indices[i] = indices[i];
     }
-    float value = input_literal->Get<float>(indices);
+    float value = input_literal.Get<float>(indices);
     for (int64 dim : spec.reversal) {
       output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim];
     }
-    expected->Set<float>(output_indices, value);
+    expected.Set<float>(output_indices, value);
   });
-  ComputeAndCompareLiteral(&builder, *expected, {});
+  ComputeAndCompareLiteral(&builder, expected, {});
 }
 
 INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest,
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index e692b8c..091a5d2 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -38,7 +38,7 @@
 class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
  protected:
   // Sends the literal to the server and retrieves it back.
-  std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+  Literal RoundTripToServer(const Literal& original) {
     std::unique_ptr<GlobalData> data =
         client_->TransferToServer(original).ConsumeValueOrDie();
     return client_->Transfer(*data).ConsumeValueOrDie();
@@ -59,12 +59,12 @@
   std::unique_ptr<tensorflow::RandomAccessFile> f;
   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
   PackedLiteralReader reader(f.release());
-  std::unique_ptr<Literal> actual =
+  Literal actual =
       reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
   EXPECT_TRUE(reader.IsExhausted());
 
-  EXPECT_EQ(42.0, actual->Get<float>({0}));
-  EXPECT_EQ(24.0, actual->Get<float>({1}));
+  EXPECT_EQ(42.0, actual.Get<float>({0}));
+  EXPECT_EQ(24.0, actual.Get<float>({1}));
 }
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
@@ -87,18 +87,17 @@
   std::unique_ptr<tensorflow::RandomAccessFile> f;
   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
   PackedLiteralReader reader(f.release());
-  std::unique_ptr<Literal> actual =
-      reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
-          .ConsumeValueOrDie();
+  Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+                       .ConsumeValueOrDie();
   EXPECT_TRUE(reader.IsExhausted());
 
-  EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
-  EXPECT_EQ(24.0f, actual->Get<float>({0, 1}));
-  EXPECT_EQ(64.0f, actual->Get<float>({1, 0}));
-  EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+  EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+  EXPECT_EQ(24.0f, actual.Get<float>({0, 1}));
+  EXPECT_EQ(64.0f, actual.Get<float>({1, 0}));
+  EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
 
-  std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
-  EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+  Literal round_tripped = RoundTripToServer(actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
 }
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -121,18 +120,17 @@
   std::unique_ptr<tensorflow::RandomAccessFile> f;
   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
   PackedLiteralReader reader(f.release());
-  std::unique_ptr<Literal> actual =
-      reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
-          .ConsumeValueOrDie();
+  Literal actual = reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+                       .ConsumeValueOrDie();
   EXPECT_TRUE(reader.IsExhausted());
 
-  EXPECT_EQ(42.0f, actual->Get<float>({0, 0}));
-  EXPECT_EQ(24.0f, actual->Get<float>({1, 0}));
-  EXPECT_EQ(64.0f, actual->Get<float>({0, 1}));
-  EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
+  EXPECT_EQ(42.0f, actual.Get<float>({0, 0}));
+  EXPECT_EQ(24.0f, actual.Get<float>({1, 0}));
+  EXPECT_EQ(64.0f, actual.Get<float>({0, 1}));
+  EXPECT_EQ(46.0f, actual.Get<float>({1, 1}));
 
-  std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
-  EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
+  Literal round_tripped = RoundTripToServer(actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(round_tripped, actual));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index a8193c2..cd5a531 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -39,69 +39,67 @@
   void RoundTripTest(const Literal& original) {
     std::unique_ptr<GlobalData> data =
         client_->TransferToServer(original).ConsumeValueOrDie();
-    std::unique_ptr<Literal> result =
-        client_->Transfer(*data).ConsumeValueOrDie();
-    EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
+    Literal result = client_->Transfer(*data).ConsumeValueOrDie();
+    EXPECT_TRUE(LiteralTestUtil::Equal(original, result));
   }
 };
 
 TEST_F(RoundTripTransferTest, R0S32) {
-  RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+  RoundTripTest(LiteralUtil::CreateR0<int32>(42));
 }
 
 TEST_F(RoundTripTransferTest, R0F32) {
-  RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+  RoundTripTest(LiteralUtil::CreateR0<float>(42.0));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len0) {
-  RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+  RoundTripTest(LiteralUtil::CreateR1<float>({}));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len2) {
-  RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+  RoundTripTest(LiteralUtil::CreateR1<float>({42.0, 64.0}));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len256) {
   std::vector<float> values(256);
   std::iota(values.begin(), values.end(), 1.0);
-  RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+  RoundTripTest(LiteralUtil::CreateR1<float>(values));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len1024) {
   std::vector<float> values(1024);
   std::iota(values.begin(), values.end(), 1.0);
-  RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+  RoundTripTest(LiteralUtil::CreateR1<float>(values));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len1025) {
   std::vector<float> values(1025);
   std::iota(values.begin(), values.end(), 1.0);
-  RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+  RoundTripTest(LiteralUtil::CreateR1<float>(values));
 }
 
 TEST_F(RoundTripTransferTest, R1F32_Len4096) {
   std::vector<float> values(4096);
   std::iota(values.begin(), values.end(), 1.0);
-  RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+  RoundTripTest(LiteralUtil::CreateR1<float>(values));
 }
 
 TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
-  RoundTripTest(
-      *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+  RoundTripTest(LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
 }
 
 TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
-  RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+  RoundTripTest(LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
 }
 
 TEST_F(RoundTripTransferTest, R3F32) {
   RoundTripTest(
-      *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
-                                     {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
 }
 
 TEST_F(RoundTripTransferTest, R4F32) {
-  RoundTripTest(*LiteralUtil::CreateR4<float>({{
+  RoundTripTest(LiteralUtil::CreateR4<float>({{
       {{10, 11, 12, 13}, {14, 15, 16, 17}},
       {{18, 19, 20, 21}, {22, 23, 24, 25}},
       {{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -109,36 +107,35 @@
 }
 
 TEST_F(RoundTripTransferTest, EmptyTuple) {
-  RoundTripTest(*LiteralUtil::MakeTuple({}));
+  RoundTripTest(LiteralUtil::MakeTuple({}));
 }
 
 TEST_F(RoundTripTransferTest, TupleOfR1F32) {
   RoundTripTest(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
-                               LiteralUtil::CreateR1<float>({3, 4}).get()}));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({1, 2}),
+                                        LiteralUtil::CreateR1<float>({3, 4})}));
 }
 
 TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
   RoundTripTest(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
-                               LiteralUtil::CreateR1<float>({3, 4}).get()}));
+      LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR1<float>({}),
+                                        LiteralUtil::CreateR1<float>({3, 4})}));
 }
 
 TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
-  RoundTripTest(
-      *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
-                               LiteralUtil::CreateR1<int>({2, 3}).get()}));
+  RoundTripTest(LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(1.0), LiteralUtil::CreateR1<int>({2, 3})}));
 }
 
 // Below two tests are added to identify the cost of large data transfers.
 TEST_F(RoundTripTransferTest, R2F32_Large) {
-  RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+  RoundTripTest(LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
 }
 
 TEST_F(RoundTripTransferTest, R4F32_Large) {
   Array4D<float> array4d(2, 2, 256, 256);
   array4d.FillWithMultiples(1.0f);
-  RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+  RoundTripTest(LiteralUtil::CreateR4FromArray4D<float>(array4d));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 07460a7..1dd937a 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -161,9 +161,9 @@
   ConvertElementType(a, F32);
 
   int64 value = 3LL << 35;
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
+  Literal a_literal = LiteralUtil::CreateR0<int64>(value);
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
   ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
                              {a_data.get()});
 }
@@ -225,20 +225,20 @@
 
 XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
   XlaBuilder builder(TestName());
-  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
-  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
-  std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+  Literal a_literal = LiteralUtil::CreateR0<float>(2.1f);
+  Literal b_literal = LiteralUtil::CreateR0<float>(5.5f);
+  Literal c_literal = LiteralUtil::CreateR0<float>(0.5f);
 
   std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+      client_->TransferToServer(a_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> b_data =
-      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+      client_->TransferToServer(b_literal).ConsumeValueOrDie();
   std::unique_ptr<GlobalData> c_data =
-      client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+      client_->TransferToServer(c_literal).ConsumeValueOrDie();
 
-  XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
-  XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
-  XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+  XlaOp a = Parameter(&builder, 0, a_literal.shape(), "a");
+  XlaOp b = Parameter(&builder, 1, b_literal.shape(), "b");
+  XlaOp c = Parameter(&builder, 2, c_literal.shape(), "c");
   Mul(Mul(a, b), c);
 
   ComputeAndCompareR0<float>(&builder, 5.775f,
@@ -377,9 +377,9 @@
         auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
         auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
         TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
-                                client_->TransferToServer(*dividend_literal));
+                                client_->TransferToServer(dividend_literal));
         TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
-                                client_->TransferToServer(*divisor_literal));
+                                client_->TransferToServer(divisor_literal));
         auto actual_literal =
             client_
                 ->ExecuteAndTransfer(div_computation,
@@ -388,7 +388,7 @@
                 .ConsumeValueOrDie();
         auto expected_literal =
             LiteralUtil::CreateR0<uint32>(dividend / divisor);
-        EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+        EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
       }
     }
   }
@@ -419,9 +419,9 @@
         auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
         auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
         TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
-                                client_->TransferToServer(*dividend_literal));
+                                client_->TransferToServer(dividend_literal));
         TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
-                                client_->TransferToServer(*divisor_literal));
+                                client_->TransferToServer(divisor_literal));
         auto actual_literal =
             client_
                 ->ExecuteAndTransfer(rem_computation,
@@ -430,7 +430,7 @@
                 .ConsumeValueOrDie();
         auto expected_literal =
             LiteralUtil::CreateR0<uint32>(dividend % divisor);
-        EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+        EXPECT_TRUE(LiteralTestUtil::Equal(expected_literal, actual_literal));
       }
     }
   }
@@ -441,8 +441,8 @@
   auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
   Rem(x, ConstantR0<int32>(&builder, 80000));
 
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
-  TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
+  Literal literal = LiteralUtil::CreateR0<int32>(87919);
+  TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
   ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
 }
 
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index 1858dce..d20dba0 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -62,13 +62,11 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
@@ -92,13 +90,12 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates =
       LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
@@ -123,13 +120,11 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
@@ -154,13 +149,11 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
@@ -185,13 +178,12 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+  Literal operand = LiteralUtil::CreateR2<float>(
       {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({2, 1});
-  std::unique_ptr<Literal> updates =
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({2, 1});
+  Literal updates =
       LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
@@ -216,13 +208,11 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({1, 1});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
@@ -247,13 +237,12 @@
       index_vector_dim=2
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+  Literal updates = LiteralUtil::CreateR3<int32>(
       {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
@@ -277,15 +266,13 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
@@ -309,15 +296,13 @@
       index_vector_dim=0
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                     {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                     {{-7, 7}, {-8, 8}, {-9, 9}}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+  Literal updates = LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
@@ -341,12 +326,11 @@
       index_vector_dim=0
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({1, 1});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({1, 1});
+  Literal updates = LiteralUtil::CreateR2<int32>({{10}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
@@ -370,13 +354,11 @@
       index_vector_dim=0
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+  Literal updates = LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, ZeroDimBounds) {
@@ -400,11 +382,10 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
-  std::unique_ptr<Literal> scatter_indices =
-      LiteralUtil::CreateR1<int32>({0, 2});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({0, 2});
+  Literal updates = LiteralUtil::CreateR2<int32>({{}, {}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
@@ -429,12 +410,11 @@
       index_vector_dim=2
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
-  std::unique_ptr<Literal> scatter_indices =
+  Literal operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+  Literal scatter_indices =
       LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
-  std::unique_ptr<Literal> updates =
-      LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal updates = LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
@@ -458,13 +438,13 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>(
       {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+  Literal updates = LiteralUtil::CreateR3<int32>(
       {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
@@ -488,13 +468,13 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+  Literal scatter_indices = LiteralUtil::CreateR2<uint32>(
       {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+  Literal updates = LiteralUtil::CreateR3<int32>(
       {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, NegativeIndex) {
@@ -518,13 +498,13 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand =
+  Literal operand =
       LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>(
       {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+  Literal updates = LiteralUtil::CreateR3<int32>(
       {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, OneScalarIndex) {
@@ -548,12 +528,12 @@
       index_vector_dim=0
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+  Literal operand = LiteralUtil::CreateR3<int32>(
       {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
-  std::unique_ptr<Literal> updates =
+  Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+  Literal updates =
       LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, ScalarUpdate) {
@@ -577,10 +557,10 @@
       index_vector_dim=0
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+  Literal scatter_indices = LiteralUtil::CreateR0<int32>(1);
+  Literal updates = LiteralUtil::CreateR0<int32>(25);
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 XLA_TEST_F(ScatterTest, EmptyIndices) {
@@ -604,10 +584,10 @@
       index_vector_dim=1
 }
 )";
-  std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
-  std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
-  std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
-  RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+  Literal operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+  Literal scatter_indices = LiteralUtil::CreateR1<int32>({});
+  Literal updates = LiteralUtil::CreateR1<int32>({});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index c9a58ae..a40c2d7 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -176,8 +176,8 @@
   XlaBuilder builder(TestName());
   auto original = ConstantR4FromArray4D(&builder, values);
   Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
-  ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
-                           &expected_literal->shape());
+  ComputeAndCompareLiteral(&builder, expected_literal, {}, ErrorSpec(0.000001),
+                           &expected_literal.shape());
 }
 
 struct R1Spec {
@@ -201,7 +201,7 @@
     auto literal = LiteralUtil::CreateR1<NativeT>(input);
 
     XlaBuilder builder(TestName());
-    auto original = Parameter(&builder, 0, literal->shape(), "p0");
+    auto original = Parameter(&builder, 0, literal.shape(), "p0");
     Slice(original, {spec.slice_start}, {spec.slice_limit},
           {spec.slice_stride});
 
@@ -213,7 +213,7 @@
     }
 
     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
-                            client_->TransferToServer(*literal));
+                            client_->TransferToServer(literal));
     ComputeAndCompareR1<NativeT>(&builder, expected, {arg.get()});
   }
 };
@@ -376,11 +376,11 @@
       input, LayoutUtil::MakeLayout(spec.layout));
 
   XlaBuilder builder(TestName());
-  auto a = Parameter(&builder, 0, literal->shape(), "p0");
+  auto a = Parameter(&builder, 0, literal.shape(), "p0");
   Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
-                          client_->TransferToServer(*literal));
+                          client_->TransferToServer(literal));
   std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
       input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
   ComputeAndCompareR2<int32>(&builder, *expected, {arg.get()});
@@ -467,9 +467,9 @@
     XlaBuilder builder(TestName());
     auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
         values, LayoutUtil::MakeLayout(spec.input_layout));
-    auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
+    auto parameter = Parameter(&builder, 0, literal.shape(), "p0");
     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
-                            client_->TransferToServer(*literal));
+                            client_->TransferToServer(literal));
     Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
     ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
   }
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8..5155f0c 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -116,13 +116,14 @@
 // array. This is uniqueness is best-effort only. Some types (half and bfloat16)
 // are not supported and uniqueness cannot be guaranteed if the number of
 // elements exceeds the number of different values supported by the type.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
-    const Shape& shape, std::minstd_rand0* engine, bool no_duplicates) {
+StatusOr<Literal> MakeFakeLiteralInternal(const Shape& shape,
+                                          std::minstd_rand0* engine,
+                                          bool no_duplicates) {
   if (ShapeUtil::IsTuple(shape)) {
-    std::vector<std::unique_ptr<Literal>> elements;
+    std::vector<Literal> elements;
     for (const Shape& element_shape : shape.tuple_shapes()) {
       TF_ASSIGN_OR_RETURN(
-          std::unique_ptr<Literal> element,
+          Literal element,
           MakeFakeLiteralInternal(element_shape, engine, no_duplicates));
       elements.push_back(std::move(element));
     }
@@ -131,60 +132,52 @@
   if (engine == nullptr) {
     return Literal::CreateFromShape(shape);
   }
-  auto literal = absl::make_unique<Literal>(shape);
+  Literal literal(shape);
   switch (shape.element_type()) {
     case BF16:
-      PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine,
+      PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
                                                     no_duplicates);
       break;
     case F16:
-      PopulateWithRandomFloatingPointData<half>(literal.get(), engine,
+      PopulateWithRandomFloatingPointData<half>(&literal, engine,
                                                 no_duplicates);
       break;
     case F32:
-      PopulateWithRandomFloatingPointData<float>(literal.get(), engine,
+      PopulateWithRandomFloatingPointData<float>(&literal, engine,
                                                  no_duplicates);
       break;
     case F64:
-      PopulateWithRandomFloatingPointData<double>(literal.get(), engine,
+      PopulateWithRandomFloatingPointData<double>(&literal, engine,
                                                   no_duplicates);
       break;
     case S8:
-      PopulateWithRandomIntegralData<int8>(literal.get(), engine,
-                                           no_duplicates);
+      PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
       break;
     case U8:
-      PopulateWithRandomIntegralData<uint8>(literal.get(), engine,
-                                            no_duplicates);
+      PopulateWithRandomIntegralData<uint8>(&literal, engine, no_duplicates);
       break;
     case S16:
-      PopulateWithRandomIntegralData<int16>(literal.get(), engine,
-                                            no_duplicates);
+      PopulateWithRandomIntegralData<int16>(&literal, engine, no_duplicates);
       break;
     case U16:
-      PopulateWithRandomIntegralData<uint16>(literal.get(), engine,
-                                             no_duplicates);
+      PopulateWithRandomIntegralData<uint16>(&literal, engine, no_duplicates);
       break;
     case S32:
-      PopulateWithRandomIntegralData<int32>(literal.get(), engine,
-                                            no_duplicates);
+      PopulateWithRandomIntegralData<int32>(&literal, engine, no_duplicates);
       break;
     case U32:
-      PopulateWithRandomIntegralData<uint32>(literal.get(), engine,
-                                             no_duplicates);
+      PopulateWithRandomIntegralData<uint32>(&literal, engine, no_duplicates);
       break;
     case S64:
-      PopulateWithRandomIntegralData<int64>(literal.get(), engine,
-                                            no_duplicates);
+      PopulateWithRandomIntegralData<int64>(&literal, engine, no_duplicates);
       break;
     case U64:
-      PopulateWithRandomIntegralData<uint64>(literal.get(), engine,
-                                             no_duplicates);
+      PopulateWithRandomIntegralData<uint64>(&literal, engine, no_duplicates);
       break;
     case PRED: {
       std::uniform_int_distribution<int> generator(0, 1);
       TF_CHECK_OK(
-          literal->Populate<bool>([&](absl::Span<const int64> /*indices*/) {
+          literal.Populate<bool>([&](absl::Span<const int64> /*indices*/) {
             return generator(*engine);
           }));
       break;
@@ -236,8 +229,8 @@
 
 // Generate random values that are constrained to the input_shape minus the
 // output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomIndex(absl::Span<const int64> index_space,
-                                         std::minstd_rand0* engine) {
+Literal MakeRandomIndex(absl::Span<const int64> index_space,
+                        std::minstd_rand0* engine) {
   std::vector<int32> start_indices(index_space.size());
   if (engine != nullptr) {
     for (int i = 0; i < index_space.size(); ++i) {
@@ -293,7 +286,7 @@
 // no constrained uses in the dataflow graph.  If such constraints exist,
 // generate a constrained literal (either bounded in the case of indices, or
 // zero in the case of init_values for reductions).
-StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
+StatusOr<Literal> CreateLiteralForConstrainedUses(
     const absl::Span<HloInstruction* const> constrained_uses,
     const HloInstruction& param, std::minstd_rand0* engine) {
   std::vector<int64> index_space;
@@ -358,9 +351,9 @@
   } else if (needs_constant) {
     switch (constant_type) {
       case ConstantType::kZero:
-        return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
+        return LiteralUtil::Zero(param.shape().element_type());
       case ConstantType::kOne:
-        return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
+        return LiteralUtil::One(param.shape().element_type());
       case ConstantType::kUnknown:
         // We want the identity element for the computation, but we don't really
         // know what it is - so any value we generate will be just as wrong.
@@ -374,34 +367,33 @@
 
 // Given a module entry parameter, use the dataflow analysis to see if a
 // special case literal must be created, or if we can generate fake data.
-StatusOr<std::unique_ptr<Literal>> MakeConstrainedArgument(
-    const HloDataflowAnalysis& dataflow, const HloInstruction& param,
-    std::minstd_rand0* engine) {
+StatusOr<Literal> MakeConstrainedArgument(const HloDataflowAnalysis& dataflow,
+                                          const HloInstruction& param,
+                                          std::minstd_rand0* engine) {
   const auto constrained_uses = FindConstrainedUses(dataflow, param);
   return CreateLiteralForConstrainedUses(constrained_uses, param, engine);
 }
 
 }  // namespace
 
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
-                                                   bool pseudo_random) {
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape, bool pseudo_random) {
   auto engine =
       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
   return MakeFakeLiteralInternal(shape, engine.get(), /*no_duplicates=*/false);
 }
 
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
-    HloModule* const module, bool pseudo_random) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+                                                 bool pseudo_random) {
   auto engine =
       pseudo_random ? absl::make_unique<std::minstd_rand0>() : nullptr;
   return MakeFakeArguments(module, engine.get());
 }
 
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
-    HloModule* const module, std::minstd_rand0* engine) {
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+                                                 std::minstd_rand0* engine) {
   TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
   const auto params = module->entry_computation()->parameter_instructions();
-  std::vector<std::unique_ptr<Literal>> arguments(params.size());
+  std::vector<Literal> arguments(params.size());
   for (int i = 0; i < params.size(); ++i) {
     arguments[i] =
         MakeConstrainedArgument(*dataflow, *params[i], engine).ValueOrDie();
@@ -417,4 +409,18 @@
       .status();
 }
 
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+                                                      HloInstruction* lhs,
+                                                      HloInstruction* rhs) {
+  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+  PrecisionConfig precision_config;
+  precision_config.mutable_operand_precision()->Resize(
+      2, PrecisionConfig::DEFAULT);
+  DotDimensionNumbers dot_dimension_numbers;
+  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+  return absl::make_unique<HloDotInstruction>(
+      shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 7790737..b3c8a73 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -24,10 +24,10 @@
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/platform.h"
 
 namespace xla {
 
@@ -57,8 +57,8 @@
 // Generates fake data in a literal of the given shape, or returns an error
 // status if the element type is currently unhandled for fake data
 // generation. See below for documentation of pseudo_random.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape,
-                                                   bool pseudo_random = true);
+StatusOr<Literal> MakeFakeLiteral(const Shape& shape,
+                                  bool pseudo_random = true);
 
 // Generates a vector of arguments containing fake data. The number, shape and
 // layout of the arguments is appropriate for given HLO module.
@@ -84,20 +84,26 @@
 // TODO(b/79942829): Make interesting argument generation fast enough that using
 // pseudo_random does not save any noticeable amount of time so that the
 // parameter can be removed.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
-    HloModule* const module, bool pseudo_random = true);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+                                                 bool pseudo_random = true);
 
 // Overload which accepts a random number generator. This enables generation of
 // different random values with sequential calls to MakeFakeArguments by reusing
 // the same generator.
-StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
-    HloModule* const module, std::minstd_rand0* engine);
+StatusOr<std::vector<Literal>> MakeFakeArguments(HloModule* const module,
+                                                 std::minstd_rand0* engine);
 
 // Check that a given module satisfies various constraints before trying to
 // execute it.
 Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
                        bool allow_mixed_precision);
 
+// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of
+// the LHS with dimension 0 of the RHS, using no batch dimensions.
+// Both the LHS and the RHS must be of rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+                                                      HloInstruction* lhs,
+                                                      HloInstruction* rhs);
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 322c8ef..181e5cb 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -85,10 +85,10 @@
       ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
     })")
                     .ValueOrDie();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
   ASSERT_EQ(args.size(), 3);
-  const Literal& index_arg = *args[0];
+  const Literal& index_arg = args[0];
 
   EXPECT_EQ(index_arg.Get<int32>({0}), 0);
 
@@ -114,10 +114,10 @@
       ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
     })")
                     .ValueOrDie();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
   ASSERT_EQ(args.size(), 5);
-  const Literal& index_arg = *args[0];
+  const Literal& index_arg = args[0];
 
   EXPECT_EQ(index_arg.Get<int32>({0}), 0);
 
@@ -140,10 +140,10 @@
 }
 )")
                     .ValueOrDie();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
   ASSERT_EQ(args.size(), 2);
-  const Literal& key_arg = *args[0];
+  const Literal& key_arg = args[0];
 
   tensorflow::gtl::FlatSet<uint32> key_set;
   for (const float& value : key_arg.data<float>()) {
@@ -163,10 +163,10 @@
 }
 )")
                     .ValueOrDie();
-  TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
                           MakeFakeArguments(module.get()));
   ASSERT_EQ(args.size(), 2);
-  const Literal& key_arg = *args[0];
+  const Literal& key_arg = args[0];
 
   tensorflow::gtl::FlatSet<int32> key_set;
   for (const int32& value : key_arg.data<int32>()) {
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index c7eb9e2..b34fd0f 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -34,9 +34,8 @@
 
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
-                          Execute(std::move(module), {}));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
 }
 
 XLA_TEST_F(TokenHloTest, TokenTree) {
@@ -50,9 +49,8 @@
 
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
-                          Execute(std::move(module), {}));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(result, LiteralUtil::CreateToken()));
 }
 
 XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
@@ -193,9 +191,8 @@
         std::unique_ptr<HloModule> module,
         HloRunner::CreateModuleFromString(module_string, debug_options));
     auto arg = LiteralUtil::CreateR0<bool>(true);
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
-                            Execute(std::move(module), {arg.get()}));
-    EXPECT_EQ(42, result->Get<int32>({}));
+    TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+    EXPECT_EQ(42, result.Get<int32>({}));
   }
 
   {
@@ -204,9 +201,8 @@
         std::unique_ptr<HloModule> module,
         HloRunner::CreateModuleFromString(module_string, debug_options));
     auto arg = LiteralUtil::CreateR0<bool>(false);
-    TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
-                            Execute(std::move(module), {arg.get()}));
-    EXPECT_EQ(7, result->Get<int32>({}));
+    TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
+    EXPECT_EQ(7, result.Get<int32>({}));
   }
 }
 
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 125513d..d6641d2 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -69,90 +69,90 @@
 };
 
 XLA_TEST_F(TransferManagerTest, TransferR0U32) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
-  const Shape& shape = literal->shape();
+  Literal literal = LiteralUtil::CreateR0<uint32>(42);
+  const Shape& shape = literal.shape();
   auto device_buffer = AllocateDeviceBuffer(shape);
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
+  LiteralTestUtil::ExpectR0Equal<uint32>(42, result);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferR1F32) {
-  std::unique_ptr<Literal> literal =
+  Literal literal =
       LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
-  const Shape& shape = literal->shape();
+  const Shape& shape = literal.shape();
   auto device_buffer = AllocateDeviceBuffer(shape);
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
   LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
-                                        *result);
+                                        result);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
   std::vector<float> test_vector(1024 * 1024);
   std::iota(test_vector.begin(), test_vector.end(), 0);
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
-  const Shape& shape = literal->shape();
+  Literal literal = LiteralUtil::CreateR1<float>(test_vector);
+  const Shape& shape = literal.shape();
   auto device_buffer = AllocateDeviceBuffer(shape);
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
+  LiteralTestUtil::ExpectR1Equal<float>(test_vector, result);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferR1U8) {
   const char* test_string = "0123456789abcdef";
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
-  const Shape& shape = literal->shape();
+  Literal literal = LiteralUtil::CreateR1U8(test_string);
+  const Shape& shape = literal.shape();
   auto device_buffer = AllocateDeviceBuffer(shape);
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_EQ(result->GetR1U8AsString(), test_string);
+  EXPECT_EQ(result.GetR1U8AsString(), test_string);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferR2F32) {
-  std::unique_ptr<Literal> literal =
+  Literal literal =
       LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
-  const Shape& shape = literal->shape();
+  const Shape& shape = literal.shape();
   auto device_buffer = AllocateDeviceBuffer(shape);
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
 }
 
 XLA_TEST_F(TransferManagerTest,
            TransferR2F32AndChangeLayoutTransferringToDevice) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
+  Literal literal = LiteralUtil::CreateR2WithLayout<float>(
       {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
   const Shape ondevice_shape =
       ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -160,101 +160,99 @@
 
   // Round trip literal through device. Set the on-device layout to something
   // different than the literal layout.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
   EXPECT_FALSE(
-      LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
+      LayoutUtil::Equal(result.shape().layout(), literal.shape().layout()));
   LiteralTestUtil::ExpectR2Equal<float>(
-      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
+      {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, result);
 }
 
 XLA_TEST_F(TransferManagerTest, TransferTuple) {
-  std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(123.0f).get(),
-       LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
-       LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
-  auto device_buffer = AllocateDeviceBuffer(literal->shape());
+  Literal literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(123.0f),
+       LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+       LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})});
+  auto device_buffer = AllocateDeviceBuffer(literal.shape());
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
-  std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
-  auto device_buffer = AllocateDeviceBuffer(literal->shape());
+  Literal literal = LiteralUtil::MakeTuple({});
+  auto device_buffer = AllocateDeviceBuffer(literal.shape());
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
-  std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(123.0f).get(),
-       LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
-            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
-           .get(),
-       LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
-  auto device_buffer = AllocateDeviceBuffer(literal->shape());
+  Literal literal = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(123.0f),
+       LiteralUtil::MakeTupleFromSlices(
+           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+  auto device_buffer = AllocateDeviceBuffer(literal.shape());
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
-  std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
+  Literal literal = LiteralUtil::CreateR1<complex64>(
       {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
-  auto device_buffer = AllocateDeviceBuffer(literal->shape());
+  auto device_buffer = AllocateDeviceBuffer(literal.shape());
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
-  std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+  Literal literal = LiteralUtil::MakeTupleFromSlices(
       {LiteralUtil::CreateR1<complex64>(
-           {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
-           .get(),
-       LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
-       LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
-  auto device_buffer = AllocateDeviceBuffer(literal->shape());
+           {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}),
+       LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}),
+       LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f))});
+  auto device_buffer = AllocateDeviceBuffer(literal.shape());
 
   // Round trip literal through device.
-  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+  ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                           device_buffer));
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal, result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
@@ -264,54 +262,52 @@
   // supported.
   auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
   TF_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<Literal> result,
+      Literal result,
       transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
 }
 
 XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
   const int64 kIterationCount = 5000;
-  std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(123.0f).get(),
-       LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
-            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
-           .get(),
-       LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
-  std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(456.0f).get(),
-       LiteralUtil::MakeTuple(
-           {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
-            LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
-           .get(),
-       LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
+  Literal literal1 = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(123.0f),
+       LiteralUtil::MakeTupleFromSlices(
+           {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}),
+            LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f})}),
+       LiteralUtil::CreateR1<float>({-10.0f, 123.0f})});
+  Literal literal2 = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(456.0f),
+       LiteralUtil::MakeTupleFromSlices(
+           {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}),
+            LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f})}),
+       LiteralUtil::CreateR1<float>({-98.0f, 153.0f})});
 
-  auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
-  auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+  auto device_buffer1 = AllocateDeviceBuffer(literal1.shape());
+  auto device_buffer2 = AllocateDeviceBuffer(literal2.shape());
 
   auto stream1 = stream_;
   auto stream2 = stream_->GetOrCreateSubStream();
 
-  std::unique_ptr<Literal> result1, result2;
+  Literal result1, result2;
 
   // Round trip literals through device in multiple streams asynchronously.
   for (int i = 0; i < kIterationCount; ++i) {
-    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, literal1,
                                                             device_buffer1));
-    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+    ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, literal2,
                                                             device_buffer2));
     TF_ASSERT_OK_AND_ASSIGN(
-        std::unique_ptr<Literal> this_result1,
+        Literal this_result1,
         transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
     TF_ASSERT_OK_AND_ASSIGN(
-        std::unique_ptr<Literal> this_result2,
+        Literal this_result2,
         transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
     result1 = std::move(this_result1);
     result2 = std::move(this_result2);
   }
 
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal1, result1));
+  EXPECT_TRUE(LiteralTestUtil::Equal(literal2, result2));
 }
 
 class TransferDeviceToHostBenchmark : public TransferManagerTest {
@@ -323,20 +319,19 @@
     tensorflow::testing::StopTiming();
     SetUp();
 
-    std::vector<std::unique_ptr<Literal>> tuple_elements;
+    std::vector<Literal> tuple_elements;
     for (int i = 0; i < num_tuple_elements; ++i) {
       tuple_elements.push_back(
           LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
     }
-    std::unique_ptr<Literal> literal =
-        LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
-    auto device_buffer = AllocateDeviceBuffer(literal->shape());
-    TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+    Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+    auto device_buffer = AllocateDeviceBuffer(literal.shape());
+    TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                            device_buffer));
     tensorflow::testing::StartTiming();
     for (int i = 0; i < iters; ++i) {
       TF_ASSERT_OK_AND_ASSIGN(
-          std::unique_ptr<Literal> result,
+          Literal result,
           transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
     }
     tensorflow::testing::StopTiming();
@@ -355,17 +350,16 @@
     tensorflow::testing::StopTiming();
     SetUp();
 
-    std::vector<std::unique_ptr<Literal>> tuple_elements;
+    std::vector<Literal> tuple_elements;
     for (int i = 0; i < num_tuple_elements; ++i) {
       tuple_elements.push_back(
           LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
     }
-    std::unique_ptr<Literal> literal =
-        LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
-    auto device_buffer = AllocateDeviceBuffer(literal->shape());
+    Literal literal = LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+    auto device_buffer = AllocateDeviceBuffer(literal.shape());
     tensorflow::testing::StartTiming();
     for (int i = 0; i < iters; ++i) {
-      TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+      TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, literal,
                                                              device_buffer));
     }
     tensorflow::testing::StopTiming();
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index f2b3b49..619d2a3 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -51,13 +51,13 @@
       {1.1f, 2.2f, 3.5f},  // row 0
       {4.8f, 5.0f, 6.7f},  // row 1
   };
-  auto value = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(constant_scalar).get(),
-       LiteralUtil::CreateR1<float>(constant_vector).get(),
-       LiteralUtil::CreateR2<float>(constant_matrix).get()});
+  auto value = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(constant_scalar),
+       LiteralUtil::CreateR1<float>(constant_vector),
+       LiteralUtil::CreateR2<float>(constant_matrix)});
 
-  ConstantLiteral(&builder, *value);
-  ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+  ConstantLiteral(&builder, value);
+  ComputeAndCompareTuple(&builder, value, {}, error_spec_);
 }
 
 // Tests a tuple made of scalar constants.
@@ -66,12 +66,12 @@
 
   const float constant_scalar1 = 7.3f;
   const float constant_scalar2 = 1.2f;
-  auto value = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
-       LiteralUtil::CreateR0<float>(constant_scalar2).get()});
+  auto value = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(constant_scalar1),
+       LiteralUtil::CreateR0<float>(constant_scalar2)});
 
-  ConstantLiteral(&builder, *value);
-  ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
+  ConstantLiteral(&builder, value);
+  ComputeAndCompareTuple(&builder, value, {}, error_spec_);
 }
 
 // Tests the creation of tuple data.
@@ -88,11 +88,11 @@
                    ConstantR1<float>(&builder, constant_vector),
                    ConstantR2<float>(&builder, constant_matrix)});
 
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR0<float>(constant_scalar).get(),
-       LiteralUtil::CreateR1<float>(constant_vector).get(),
-       LiteralUtil::CreateR2<float>(constant_matrix).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(constant_scalar),
+       LiteralUtil::CreateR1<float>(constant_vector),
+       LiteralUtil::CreateR2<float>(constant_matrix)});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 // Tests the creation of tuple data.
@@ -102,10 +102,9 @@
   Tuple(&builder,
         {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
 
-  auto expected =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
-                              LiteralUtil::CreateR1<float>({}).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR0<float>(7.0), LiteralUtil::CreateR1<float>({})});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 // Tests the creation of an empty tuple.
@@ -113,7 +112,7 @@
   XlaBuilder builder(TestName());
   Tuple(&builder, {});
   auto expected = LiteralUtil::MakeTuple({});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 // Trivial test for extracting a tuple element with GetTupleElement.
@@ -196,10 +195,10 @@
                        ConstantR2<float>(&builder, constant_matrix)});
   Tuple(&builder,
         {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::CreateR2<float>(constant_matrix).get(),
-       LiteralUtil::CreateR1<float>(constant_vector).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR2<float>(constant_matrix),
+       LiteralUtil::CreateR1<float>(constant_vector)});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
@@ -218,11 +217,11 @@
     auto v1_v2 = Tuple(&b, {v1_gt, v2_gt});  // {false, true}
     auto v2_v1 = Tuple(&b, {v2_gt, v1_gt});  // {true, false}
     Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
-    auto expected =
-        LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
-                                LiteralUtil::CreateR0<bool>(!direction).get()});
+    auto expected = LiteralUtil::MakeTupleFromSlices(
+        {LiteralUtil::CreateR0<bool>(direction),
+         LiteralUtil::CreateR0<bool>(!direction)});
 
-    ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
+    ComputeAndCompareTuple(&b, expected, {v1_data.get(), v2_data.get()},
                            error_spec_);
   }
 }
@@ -287,10 +286,9 @@
                                   ConstantR1<float>(&builder, vec1)});
 
   Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
-  auto expected =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
-                              LiteralUtil::CreateR1<float>(vec1).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, TuplesInAMap) {
@@ -332,10 +330,9 @@
                                   ConstantR1<float>(&builder, vec1)});
 
   Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
-  auto expected =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
-                              LiteralUtil::CreateR1<float>(vec2).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<float>(vec1), LiteralUtil::CreateR1<float>(vec2)});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
@@ -408,10 +405,9 @@
 
   Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
 
-  auto expected =
-      LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
-                              LiteralUtil::CreateR1<float>(vec1).get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::CreateR1<float>(vec2), LiteralUtil::CreateR1<float>(vec1)});
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, NestedTuples) {
@@ -423,12 +419,11 @@
   auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
   auto expected_s = LiteralUtil::CreateR0<float>(42.0);
   auto expected_inner_tuple =
-      LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+      LiteralUtil::MakeTuple({&expected_v1, &expected_s});
   auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
-  auto expected =
-      LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+  auto expected = LiteralUtil::MakeTuple({&expected_inner_tuple, &expected_v2});
 
-  ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
 }
 
 XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
@@ -446,14 +441,12 @@
 
   std::unique_ptr<GlobalData> data =
       client_
-          ->TransferToServer(*LiteralUtil::MakeTuple({
-              LiteralUtil::MakeTuple(
-                  {
-                      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
-                      LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
-                  })
-                  .get(),
-              LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+          ->TransferToServer(LiteralUtil::MakeTupleFromSlices({
+              LiteralUtil::MakeTupleFromSlices({
+                  LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}),
+                  LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}),
+              }),
+              LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}),
           }))
           .ConsumeValueOrDie();
 
@@ -484,40 +477,36 @@
 
   std::unique_ptr<GlobalData> arg0 =
       client_
-          ->TransferToServer(*LiteralUtil::MakeTuple(
-              {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
-               LiteralUtil::MakeTuple(
-                   {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
-                        .get(),
+          ->TransferToServer(LiteralUtil::MakeTupleFromSlices(
+              {LiteralUtil::CreateR0<complex64>({1, 2}),
+               LiteralUtil::MakeTupleFromSlices(
+                   {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}}),
                     LiteralUtil::CreateR2<complex64>(
                         {{{100, 200}, {300, 400}},
                          {{1000, 2000}, {3000, 4000}},
-                         {{10000, 20000}, {30000, 40000}}})
-                        .get()})
-                   .get()}))
+                         {{10000, 20000}, {30000, 40000}}})})}))
           .ConsumeValueOrDie();
   std::unique_ptr<GlobalData> arg1 =
       client_
           ->TransferToServer(
-              *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
+              LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
           .ConsumeValueOrDie();
   auto sum =
       LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
                                         {{1011, 2022}, {3031, 4042}},
                                         {{10011, 20022}, {30031, 40042}}});
-  auto prod = absl::make_unique<Literal>(sum->shape());
-  ASSERT_TRUE(prod->Populate<complex64>(
-                      [&sum](absl::Span<const int64> indexes) {
-                        return sum->Get<complex64>(indexes) *
-                               (indexes[indexes.size() - 1] == 0
-                                    ? complex64(1, 2)
-                                    : complex64(1, -2));
-                      })
+  Literal prod(sum.shape());
+  ASSERT_TRUE(prod.Populate<complex64>([&sum](absl::Span<const int64> indexes) {
+                    return sum.Get<complex64>(indexes) *
+                           (indexes[indexes.size() - 1] == 0
+                                ? complex64(1, 2)
+                                : complex64(1, -2));
+                  })
                   .ok());
-  auto expected = LiteralUtil::MakeTuple(
-      {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
-       LiteralUtil::CreateR0<complex64>({123, 456}).get()});
-  ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
+  auto expected = LiteralUtil::MakeTupleFromSlices(
+      {LiteralUtil::MakeTupleFromSlices({prod, sum}),
+       LiteralUtil::CreateR0<complex64>({123, 456})});
+  ComputeAndCompareTuple(&builder, expected, {arg0.get(), arg1.get()},
                          error_spec_);
 }
 
@@ -541,10 +530,10 @@
           .ValueOrDie();
   auto param =
       LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
-  auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
+  auto result = ExecuteNoHloPasses(std::move(module), {&param});
   EXPECT_TRUE(LiteralTestUtil::Equal(
-      *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
-      *result));
+      LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+      result));
 }
 
 // Disabled on interpreter due to lack of outfeed.
@@ -581,16 +570,15 @@
       tensorflow::Env::Default()->StartThread(
           tensorflow::ThreadOptions(), "execute_thread", [&] {
             TF_EXPECT_OK(Execute(std::move(module),
-                                 {param0.get(), param1.get(), param1.get(),
-                                  param0.get(), param4.get()})
+                                 {&param0, &param1, &param1, &param0, &param4})
                              .status());
           }));
   auto expected =
       LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({2, 3}));
-  auto literal = Literal::CreateFromShape(expected->shape());
+  auto literal = Literal::CreateFromShape(expected.shape());
   TF_EXPECT_OK(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      backend().default_stream_executor(), expected->shape(), *literal));
-  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *literal));
+      backend().default_stream_executor(), expected.shape(), literal));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 8f80a9f..4fbd7f2 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -100,9 +100,9 @@
                                               {-inf<float>(), 0}});
   Abs(arg);
 
-  std::unique_ptr<Literal> expected =
+  Literal expected =
       LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
-  ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+  ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
 }
 
 template <>
@@ -113,9 +113,9 @@
       {{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
   Sign(arg);
 
-  std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
+  Literal expected = LiteralUtil::CreateR1<complex64>(
       {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
-  ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+  ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
 }
 
 template <>
@@ -127,9 +127,8 @@
   auto abs = Abs(arg);
   Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
-  ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+  Literal expected = LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
+  ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
 }
 
 XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
@@ -172,9 +171,8 @@
   Add(sgnc, ConvertElementType(
                 Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
 
-  std::unique_ptr<Literal> expected =
-      LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
-  ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
+  Literal expected = LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
+  ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f));
 }
 
 XLA_TEST_F(UnaryOpTest, SignTestR1) {
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf186..7abd865 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -348,9 +348,9 @@
   // have all reached 2.0.
   auto expected_data =
       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
-  auto expected = LiteralUtil::MakeTuple({expected_data.get()});
-  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+  auto expected = LiteralUtil::MakeTuple({&expected_data});
+  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
 TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
@@ -401,11 +401,10 @@
   auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
   auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
   auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
-  auto expected =
-      LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
-                              expected_w3.get(), expected_w1.get()});
-  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+  auto expected = LiteralUtil::MakeTuple(
+      {&expected_counter, &expected_w2, &expected_w3, &expected_w1});
+  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
 TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
@@ -510,10 +509,9 @@
   auto expected_counter = LiteralUtil::CreateR0<int32>(5);
   auto expected_data = LiteralUtil::CreateR1<float>(
       {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
-  auto expected =
-      LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
-  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
 TEST_F(WhileTest, WhileWithPredicateTupleResult) {
@@ -557,9 +555,9 @@
 
   auto expected_counter = LiteralUtil::CreateR0<int32>(5);
   auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
-  auto expected = LiteralUtil::MakeTuple(
-      {expected_counter.get(), expected_predicate.get()});
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
+  auto expected =
+      LiteralUtil::MakeTuple({&expected_counter, &expected_predicate});
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0));
 }
 
 TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
@@ -602,10 +600,9 @@
 
   auto expected_counter = LiteralUtil::CreateR0<int32>(5);
   auto expected_data = LiteralUtil::CreateR0<int32>(7);
-  auto expected =
-      LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
-  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
 // Tests two while nodes when the result type T is a Tuple and the second
@@ -886,10 +883,9 @@
   auto expected_counter = LiteralUtil::CreateR0<int32>(5);
   auto expected_data = LiteralUtil::CreateR1<float>(
       {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
-  auto expected =
-      LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
-  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
-  ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+  auto expected = LiteralUtil::MakeTuple({&expected_counter, &expected_data});
+  VLOG(2) << "expected = " << ShapeUtil::HumanString(expected.shape());
+  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.0001));
 }
 
 // Tests a while node when the result type T is a vector of S32.
@@ -977,11 +973,11 @@
 
   auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
   auto expected =
-      LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
+      LiteralUtil::MakeTuple({&expected_element, &expected_element});
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> parameter_data,
-      client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
-  ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
+  ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
                          ErrorSpec(1e-6));
 }
 
@@ -1005,7 +1001,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> parameter_data,
-      client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
+      client_->TransferToServer(LiteralUtil::CreateR1<float>({42, 42})));
   ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
                              ErrorSpec(1e-6));
 }
@@ -1031,7 +1027,7 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> parameter_data,
-      client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
+      client_->TransferToServer(LiteralUtil::CreateR0<float>(42)));
   ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
                              ErrorSpec(1e-6));
 }
@@ -1070,12 +1066,12 @@
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<GlobalData> parameter_data,
-      client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
+      client_->TransferToServer(LiteralUtil::CreateR0<int32>(1)));
 
   auto add1 = LiteralUtil::CreateR0<int32>(15);
   auto add2 = LiteralUtil::CreateR0<int32>(16);
-  auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
-  ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
+  auto expected = LiteralUtil::MakeTuple({&add1, &add2});
+  ComputeAndCompareTuple(&outer, expected, {parameter_data.get()},
                          ErrorSpec(1e-6));
 }
 
@@ -1228,7 +1224,7 @@
   GetTupleElement(while_instruction, 3);
 
   TF_ASSERT_OK_AND_ASSIGN(
-      auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+      auto param_value, client_->TransferToServer(LiteralUtil::CreateR2<float>(
                             {{1.0, 2.0}, {-1.0, -2.0}})));
 
   ComputeAndCompareR2<float>(
@@ -1258,9 +1254,9 @@
   XlaBuilder builder(TestName());
   While(condition, body, ConstantR0<int32>(&builder, 0));
 
-  TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
-  TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
-  TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(true)));
+  TF_ASSERT_OK(client_->TransferToInfeed(LiteralUtil::CreateR0<bool>(false)));
 
   ComputeAndCompareR0<int32>(&builder, 2, {});
 }
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 7fd4294..db5a824 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -144,14 +144,14 @@
       transfer_manager->AllocateScopedShapedBuffer(
           lhs_arg_shape, allocator, backend->default_device_ordinal()));
   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
-      stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+      stream_ptr.get(), Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
 
   TF_ASSERT_OK_AND_ASSIGN(
       ScopedShapedBuffer rhs_arg,
       transfer_manager->AllocateScopedShapedBuffer(
           rhs_arg_shape, allocator, backend->default_device_ordinal()));
   TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
-      stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+      stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
 
   TF_ASSERT_OK_AND_ASSIGN(
       std::unique_ptr<LocalExecutable> local_executable,
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 442e663..cdde88c 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -39,8 +39,7 @@
 
 namespace xla {
 
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
-    absl::string_view path) {
+StatusOr<Literal> TextLiteralReader::ReadPath(absl::string_view path) {
   CHECK(!absl::EndsWith(path, ".gz"))
       << "TextLiteralReader no longer supports reading .gz files";
   std::unique_ptr<tensorflow::RandomAccessFile> file;
@@ -57,7 +56,7 @@
 TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
     : file_(file) {}
 
-StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+StatusOr<Literal> TextLiteralReader::ReadAllLines() {
   tensorflow::io::RandomAccessInputStream stream(file_.get());
   tensorflow::io::BufferedInputStream buf(&stream, 65536);
   string shape_string;
@@ -74,9 +73,9 @@
         ShapeUtil::HumanString(shape));
   }
 
-  auto result = absl::make_unique<Literal>(shape);
+  Literal result(shape);
   const float fill = std::numeric_limits<float>::quiet_NaN();
-  result->PopulateWithValue<float>(fill);
+  result.PopulateWithValue<float>(fill);
   std::vector<absl::string_view> pieces;
   std::vector<absl::string_view> coordinates;
   std::vector<int64> coordinate_values;
@@ -116,7 +115,7 @@
           "\"%s\"",
           shape.dimensions_size(), coordinate_values.size(), line);
     }
-    result->Set<float>(coordinate_values, value);
+    result.Set<float>(coordinate_values, value);
   }
   return std::move(result);
 }
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index b265640..c40b432 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -41,7 +41,7 @@
  public:
   // See class comment -- reads a file in its entirety (there must be only one
   // literal in the text file path provided).
-  static StatusOr<std::unique_ptr<Literal>> ReadPath(absl::string_view path);
+  static StatusOr<Literal> ReadPath(absl::string_view path);
 
  private:
   // Ownership of file is transferred.
@@ -49,7 +49,7 @@
 
   // Parses a shape string on the first line, followed by lines of values to the
   // end of the file.
-  StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+  StatusOr<Literal> ReadAllLines();
 
   // Owns the file being read
   std::unique_ptr<tensorflow::RandomAccessFile> file_;
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 92f9b4f..1fab4e3 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -42,16 +42,15 @@
       tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
           .ok());
 
-  std::unique_ptr<Literal> literal =
-      TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+  Literal literal = TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
   EXPECT_TRUE(
-      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
-  EXPECT_EQ(42.5, literal->Get<float>({0, 0, 0}));
-  EXPECT_EQ(43.5, literal->Get<float>({0, 0, 1}));
-  EXPECT_EQ(44.5, literal->Get<float>({0, 0, 2}));
-  EXPECT_EQ(45.5, literal->Get<float>({0, 1, 0}));
-  EXPECT_EQ(46.5, literal->Get<float>({0, 1, 1}));
-  EXPECT_EQ(47.5, literal->Get<float>({0, 1, 2}));
+      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal.shape()));
+  EXPECT_EQ(42.5, literal.Get<float>({0, 0, 0}));
+  EXPECT_EQ(43.5, literal.Get<float>({0, 0, 1}));
+  EXPECT_EQ(44.5, literal.Get<float>({0, 0, 2}));
+  EXPECT_EQ(45.5, literal.Get<float>({0, 1, 0}));
+  EXPECT_EQ(46.5, literal.Get<float>({0, 1, 1}));
+  EXPECT_EQ(47.5, literal.Get<float>({0, 1, 2}));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 4ea02fa..5cbaf2f 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -37,7 +37,7 @@
   });
   string path =
       tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
-  ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+  ASSERT_IS_OK(TextLiteralWriter::WriteToPath(literal, path));
   string contents;
   TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
                                            &contents));
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index 23ce1d2..0c3ec59 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -67,8 +67,8 @@
     floats.push_back(value);
   }
 
-  absl::string_view content(absl::bit_cast<const char*>(floats.data()),
-                            floats.size() * sizeof(float));
+  tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()),
+                                  floats.size() * sizeof(float));
   TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
                                             output_file, content));
   return 0;
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index ba814af..0c41f22 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -121,11 +121,10 @@
     }
   } else {  // use recorded data if available
     for (const auto& proto : module.arguments()) {
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
-                          Literal::CreateFromProto(proto));
+      TF_ASSIGN_OR_RETURN(Literal literal, Literal::CreateFromProto(proto));
       TF_ASSIGN_OR_RETURN(
           ScopedShapedBuffer data,
-          client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+          client->LiteralToShapedBuffer(literal, /*device_ordinal=*/0));
       scoped_shaped_buffer_arguments.push_back(std::move(data));
     }
     for (const auto& argument : scoped_shaped_buffer_arguments) {
@@ -161,12 +160,12 @@
   // --generate_fake_infeed is passed and there exists an infeed operation in
   // the HloSnapshot.
   absl::optional<tensorflow::thread::ThreadPool> pool;
-  std::unique_ptr<Literal> data;
+  Literal data;
   if (provide_infeed) {
     data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie();
   }
   auto transfer_infeed = [&data, client]() {
-    TF_CHECK_OK(client->TransferToInfeed(*data));
+    TF_CHECK_OK(client->TransferToInfeed(data));
   };
   if (provide_infeed) {
     pool.emplace(tensorflow::Env::Default(), "infeed",
@@ -214,9 +213,9 @@
               << "s: " << module.hlo().hlo_module().name();
   }
 
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+  TF_ASSIGN_OR_RETURN(Literal result_literal,
                       client->ShapedBufferToLiteral(*result));
-  return std::move(*result_literal);
+  return result_literal;
 }
 
 StatusOr<HloSnapshot> ParseInputFile(const string& filename,
@@ -305,11 +304,11 @@
               result.ToString().c_str());
       auto& snapshot = snapshots[i];
       if (snapshot.has_result()) {
-        std::unique_ptr<Literal> literal =
+        Literal literal =
             Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
         fprintf(stdout, "was %s:%s\n",
                 ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
-                literal->ToString().c_str());
+                literal.ToString().c_str());
       }
     }
   }
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index 5190919..4f8852f 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -40,8 +40,8 @@
   xla::LiteralProto literal_proto;
   TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
                                           &literal_proto));
-  std::unique_ptr<xla::Literal> literal =
+  xla::Literal literal =
       xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie();
   LOG(INFO) << "literal: " << literal_proto.ShortDebugString();
-  fprintf(stderr, "%s\n", literal->ToString().c_str());
+  fprintf(stderr, "%s\n", literal.ToString().c_str());
 }
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 48c8374..4b5c276 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -36,16 +36,16 @@
     LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
   }
 
-  std::unique_ptr<xla::Literal> literal =
+  xla::Literal literal =
       xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
 
-  LOG(INFO) << "literal: " << *literal;
-  fprintf(stderr, "%s\n", literal->ToString().c_str());
-  if (literal->shape().element_type() == xla::F32) {
-    float min = *std::min_element(literal->data<float>().begin(),
-                                  literal->data<float>().end());
-    float max = *std::max_element(literal->data<float>().begin(),
-                                  literal->data<float>().end());
+  LOG(INFO) << "literal: " << literal;
+  fprintf(stderr, "%s\n", literal.ToString().c_str());
+  if (literal.shape().element_type() == xla::F32) {
+    float min = *std::min_element(literal.data<float>().begin(),
+                                  literal.data<float>().end());
+    float max = *std::max_element(literal.data<float>().begin(),
+                                  literal.data<float>().end());
     fprintf(stderr, "min: %a=%f\n", min, min);
     fprintf(stderr, "max: %a=%f\n", max, max);
   }
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 8e43f27..73b3589 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -351,6 +351,7 @@
 message LiteralProto {
   Shape shape = 1;
   repeated bool preds = 2;
+  bytes s8s = 15;
   bytes u8s = 3;
   repeated int32 s32s = 4;
   repeated int64 s64s = 5;
@@ -364,7 +365,7 @@
   bytes f16s = 11;
   bytes bf16s = 13;
   repeated int64 sparse_indices = 14;
-  // Next = 15
+  // Next = 16
 }
 
 message WindowDimension {
@@ -580,7 +581,7 @@
 
 // Used to indicate the precision configuration. It has backend specific
 // meaning.
-message PrecisionConfigProto {
+message PrecisionConfig {
   enum Precision {
     DEFAULT = 0;
     HIGH = 1;
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
index efbe980..2ff9791 100644
--- a/tensorflow/compiler/xrt/BUILD
+++ b/tensorflow/compiler/xrt/BUILD
@@ -56,6 +56,7 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/stream_executor",
         "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ],
 )
diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
index 68ba17a..9e3d245 100644
--- a/tensorflow/compiler/xrt/kernels/BUILD
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -46,19 +46,15 @@
     deps = [
         ":xrt_state_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:client_library",
-        "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:computation_placer",
-        "//tensorflow/compiler/xla/service:hlo_proto",
         "//tensorflow/compiler/xrt:xrt_proto",
         "//tensorflow/compiler/xrt:xrt_utils",
         "//tensorflow/core:core_cpu_internal",
@@ -67,6 +63,7 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/stream_executor:stream_executor_headers_lib",
+        "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index 5cf2bc8..1d4f8d9 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -22,6 +22,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -40,7 +41,6 @@
 #include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -70,7 +70,7 @@
   string serialized;
   TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
   uint64 fingerprint = Fingerprint64(serialized);
-  *key = strings::StrCat(fingerprint);
+  *key = absl::StrCat(fingerprint);
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
index 478c966..54b0655 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -49,7 +49,7 @@
   // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
   // OpKernel::Compute method.
   static Status MakeLiteral(const xla::LiteralProto& proto,
-                            std::unique_ptr<xla::Literal>* literal) {
+                            xla::Literal* literal) {
     TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
     return Status::OK();
   }
@@ -173,7 +173,7 @@
         errors::InvalidArgument(
             "Unable to parse allocation input to XLAAllocation"));
 
-    std::unique_ptr<xla::Literal> literal;
+    xla::Literal literal;
     OP_REQUIRES_OK(
         ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
 
@@ -189,7 +189,7 @@
 
     XRTTupleAllocation* allocation;
     OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
-                            *literal, device_ref.backend(),
+                            literal, device_ref.backend(),
                             device_ref.device_ordinal(), &allocation));
 
     // Intern takes ownership of our reference to allocation.
@@ -381,11 +381,11 @@
     OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
                             ctx, allocation->device_ordinal(), &device_ref));
 
-    std::unique_ptr<xla::Literal> literal;
+    xla::Literal literal;
     OP_REQUIRES_OK(
         ctx, allocation->ToLiteral(device_ref.backend(),
                                    device_ref.device_ordinal(), &literal));
-    xla::LiteralProto literal_proto = literal->ToProto();
+    xla::LiteralProto literal_proto = literal.ToProto();
 
     Tensor output(DT_STRING, TensorShape({}));
     literal_proto.SerializeToString(&output.scalar<string>()());
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 5b8516b..2952feb 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -52,44 +52,44 @@
 xla::LiteralProto TwoElementTuple() {
   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
-  auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
-  return tuple->ToProto();
+  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+  return tuple.ToProto();
 }
 
 xla::LiteralProto ScalarLiteral() {
   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
-  return scalar->ToProto();
+  return scalar.ToProto();
 }
 
 xla::LiteralProto NestedTuple() {
   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
-  auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
-  auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
-  return nested->ToProto();
+  auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
+  return nested.ToProto();
 }
 
 xla::LiteralProto MakeTuple0() {
   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
-  auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
-  auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()});
-  auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()});
-  return nested1->ToProto();
+  auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
+  auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
+  auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
+  return nested1.ToProto();
 }
 
-xla::LiteralProto FloatVector(gtl::ArraySlice<float> v) {
+xla::LiteralProto FloatVector(absl::Span<const float> v) {
   auto array = xla::LiteralUtil::CreateR1<float>(v);
-  return array->ToProto();
+  return array.ToProto();
 }
 
 bool CompareLiteralProtos(const xla::LiteralProto& a,
                           const xla::LiteralProto& b) {
   auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
-  bool equal = *l_a == *l_b;
+  bool equal = l_a == l_b;
   if (!equal) {
     LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
               << " != " << b.DebugString();
@@ -100,7 +100,7 @@
 bool CompareLiteralToLiteralProto(const xla::Literal& a,
                                   const xla::LiteralProto& b) {
   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
-  bool equal = a == *l_b;
+  bool equal = a == l_b;
   if (!equal) {
     LOG(INFO) << "Literal and LiteralProto don't match "
               << a.ToProto().DebugString() << " != " << b.DebugString();
@@ -211,7 +211,7 @@
   TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
 
   auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
-  auto base_elements = base_literal->DecomposeTuple();
+  auto base_elements = base_literal.DecomposeTuple();
   auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
   xla::LiteralProto response_0;
   EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
@@ -343,7 +343,7 @@
   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
 
   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
-  EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 }
 
 TEST(RawApiTest, CompileAndExecuteReturnTuple) {
@@ -392,8 +392,8 @@
   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
 
   auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
-  auto expected = xla::LiteralUtil::MakeTuple({sum.get()});
-  EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+  auto expected = xla::LiteralUtil::MakeTuple({&sum});
+  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 911ac9a..d05a1e7 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -24,6 +24,7 @@
 #include <utility>
 
 #include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -32,7 +33,6 @@
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 
@@ -174,7 +174,7 @@
 }
 
 Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
-                                     std::unique_ptr<xla::Literal>* literal) {
+                                     xla::Literal* literal) {
   auto transfer_manager = backend->transfer_manager();
   TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
   TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
@@ -201,14 +201,14 @@
 
 /*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
                                              XRTTupleAllocation** allocation) {
-  string key_string = strings::StrCat(key);
+  string key_string = absl::StrCat(key);
   TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
   return Status::OK();
 }
 
 /*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
                                                                 int64 key) {
-  string key_string = strings::StrCat(key);
+  string key_string = absl::StrCat(key);
   return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
 }
 
@@ -410,7 +410,7 @@
 
 Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
   *key = get_uid();
-  string key_string = strings::StrCat(*key);
+  string key_string = absl::StrCat(*key);
   return rm->Create(kTupleContainer, key_string, this);
 }
 
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
index 4270568..73b5584 100644
--- a/tensorflow/compiler/xrt/xrt_state.h
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -135,7 +135,7 @@
 
   // Copies the allocation from device to host and returns it in literal.
   Status ToLiteral(xla::Backend* backend, int device_ordinal,
-                   std::unique_ptr<xla::Literal>* literal);
+                   xla::Literal* literal);
 
   // True if none of the buffers in the allocation are aliased by any other live
   // handle.
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 6698380..d98a249 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -20,13 +20,7 @@
     ),
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
-    deps = if_not_windows([
-        # TODO(aaroey): tensorrt dependency has to appear before tflite so the
-        # build can resolve its flatbuffers symbols within the tensorrt library.
-        # This is an issue with the tensorrt static library and will be fixed by
-        # the next tensorrt release, so fix the order here after that.
-        "//tensorflow/contrib/tensorrt:init_py",  # doesn't compile on windows
-    ]) + [
+    deps = [
         "//tensorflow/contrib/all_reduce",
         "//tensorflow/contrib/batching:batch_py",
         "//tensorflow/contrib/bayesflow:bayesflow_py",
@@ -135,6 +129,7 @@
     ]) + if_not_windows([
         "//tensorflow/contrib/bigtable",  # depends on bigtable
         "//tensorflow/contrib/cloud:cloud_py",  # doesn't compile on Windows
+        "//tensorflow/contrib/tensorrt:init_py",  # doesn't compile on Windows
         "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
     ]),
 )
@@ -171,7 +166,9 @@
             "//tensorflow/contrib/kinesis:dataset_kernels",
         ],
         "//conditions:default": [],
-    }),
+    }) + if_not_windows([
+        "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
+    ]),
 )
 
 cc_library(
@@ -208,5 +205,7 @@
             "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
         ],
         "//conditions:default": [],
-    }),
+    }) + if_not_windows([
+        "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
+    ]),
 )
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 5f477a7..9478e42 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,6 +21,14 @@
 
 import os
 
+from tensorflow.python.tools import component_api_helper
+component_api_helper.package_hook(
+    parent_package_str=(
+        "tensorflow.contrib"),
+    child_package_str=(
+        "tensorflow_estimator.contrib.estimator"))
+del component_api_helper
+
 # Add projects here, they will show up under tf.contrib.
 from tensorflow.contrib import autograph
 from tensorflow.contrib import batching
diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD
index ad700ac..e37ad7a 100644
--- a/tensorflow/contrib/autograph/BUILD
+++ b/tensorflow/contrib/autograph/BUILD
@@ -21,11 +21,9 @@
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
+    # This module is kept for backward compatibility only. To depend on AutoGraph,
+    # use //tensorflow/python/autograph instead.
     deps = [
-        "//tensorflow/contrib/autograph/impl",
-        "//tensorflow/contrib/autograph/lang",
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/utils",
-        "//tensorflow/python:util",
+        "//tensorflow/python/autograph",
     ],
 )
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index cc54da4..6ea2db7 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -1,5 +1,12 @@
 # AutoGraph
 
+**NOTE: As `tensorflow.contrib` is being
+[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
+moving into TensorFlow core.**
+
+**The new code location is `tensorflow/python/autograph`.**
+
+
 IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
 
 AutoGraph is a Python to TensorFlow compiler.
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index 26e7a4a..137bc59 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -12,57 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Autograph compiles Python code into equivalent TensorFlow code.
+"""This is the legacy module for AutoGraph, kept for backward compatibility.
 
-Equivalent here means that they have the same effect when executed.
+New users should instead use `tensorflow.python.autograph`.
 """
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# TODO(mdan): Bring only the relevant symbols to the top level.
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core.errors import GraphConstructionError
-from tensorflow.contrib.autograph.core.errors import TfRuntimeError
-from tensorflow.contrib.autograph.core.errors import improved_errors
-from tensorflow.contrib.autograph.impl.api import RunMode
-from tensorflow.contrib.autograph.impl.api import convert
-from tensorflow.contrib.autograph.impl.api import converted_call
-from tensorflow.contrib.autograph.impl.api import do_not_convert
-from tensorflow.contrib.autograph.impl.api import to_code
-from tensorflow.contrib.autograph.impl.api import to_graph
-from tensorflow.contrib.autograph.lang.directives import set_element_type
-from tensorflow.contrib.autograph.lang.directives import set_loop_options
-from tensorflow.contrib.autograph.lang.special_functions import stack
-from tensorflow.contrib.autograph.lang.special_functions import tensor_list
-from tensorflow.contrib.autograph.pyct.transformer import AutographParseError
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
-    # Main API
-    'RunMode',
-    'convert',
-    'converted_call',
-    'do_not_convert',
-    'to_code',
-    'to_graph',
-    # Overloaded operators
-    'operators',
-    # Errors
-    'improved_errors',
-    'GraphConstructionError',
-    'TfRuntimeError',
-    # Python language "extensions"
-    'set_element_type',
-    'set_loop_options',
-    'stack',
-    'tensor_list',
-    # Exceptions
-    'AutographParseError',
-    # Utilities: to be removed
-    'utils',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
+from tensorflow.python.autograph import *  # pylint:disable=wildcard-import
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
deleted file mode 100644
index b26c522..0000000
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Handles builtins and other special functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
-
-
-class BuiltinFunctionTransformer(converter.Base):
-  """Handles builtin functions.
-
-  This transformer only covers functions that are translated into a
-  TF equivalent, like `len`.
-  """
-
-  def _convert_builtin(self, node):
-    template = """
-      ag__.utils.dynamic_builtin(func, args)
-    """
-    return templates.replace(template, func=node.func, args=node.args)[0].value
-
-  def _convert_print(self, node):
-    template = """
-      ag__.utils.dynamic_print(args)
-    """
-    return templates.replace(template, args=node.args)[0].value
-
-  def visit_Call(self, node):
-    self.generic_visit(node)
-    # TODO(mdan): This won't work if the function was hidden.
-    # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
-    if (isinstance(node.func, gast.Name) and
-        node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
-      return self._convert_builtin(node)
-    # Print needs to be handled separately because it can be read as statement.
-    if isinstance(node.func, gast.Name) and node.func.id == 'print':
-      return self._convert_print(node)
-    return node
-
-  def visit_Print(self, node):
-    self.generic_visit(node)
-    args = node.values
-    # Following is the case when calling print(a, b)
-    if len(args) == 1 and isinstance(args[0], gast.Tuple):
-      args = args[0].elts
-    template = """
-      fname(args)
-    """
-    function_call = templates.replace(template, fname='print', args=args)[0]
-    return self.visit(function_call)
-
-
-def transform(node, ctx):
-  return BuiltinFunctionTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
deleted file mode 100644
index 57b5f74..0000000
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility module that contains APIs usable in the generated code."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_print
-from tensorflow.contrib.autograph.utils.builtins import dynamic_range
-from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
-from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
-from tensorflow.contrib.autograph.utils.testing import fake_tf
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
deleted file mode 100644
index 4dd440e..0000000
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Builtin conversion utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import list_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-
-
-def dynamic_builtin(f, *args, **kwargs):
-  """Converts a builtin function call inline."""
-  if f is len:
-    return dynamic_len(*args, **kwargs)
-  if six.PY2 and f is xrange:
-    return dynamic_range(*args, **kwargs)
-  if f is range:
-    return dynamic_range(*args, **kwargs)
-  if f is int:
-    return dynamic_int(*args, **kwargs)
-  if f is float:
-    return dynamic_float(*args, **kwargs)
-  if f is abs:
-    return dynamic_abs(*args, **kwargs)
-
-  raise NotImplementedError(
-      'The "%s" builtin is not yet supported.' % f.__name__)
-
-
-def dynamic_len(list_or_tensor):
-  """Implementation of len using dynamic dispatch."""
-  if _is_tensor_list(list_or_tensor):
-    return list_ops.tensor_list_length(list_or_tensor)
-  elif tensor_util.is_tensor(list_or_tensor):
-    shape = list_or_tensor.shape
-    if not shape.ndims:
-      raise ValueError(
-          'len requires non-zero rank for tensor "%s"' % list_or_tensor)
-    return array_ops.shape(list_or_tensor)[0]
-  return len(list_or_tensor)
-
-
-def _is_tensor_list(list_or_tensor):
-  return (tensor_util.is_tensor(list_or_tensor)
-          and list_or_tensor.dtype == dtypes.variant)
-
-
-def dynamic_int(num_or_tensor, **kwargs):
-  """Implementation of int() using dynamic dispatch."""
-  if tensor_util.is_tensor(num_or_tensor):
-    return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
-  return int(num_or_tensor)
-
-
-def dynamic_float(num_or_tensor, **kwargs):
-  """Implementation of float() using dynamic dispatch."""
-  if tensor_util.is_tensor(num_or_tensor):
-    return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
-  return float(num_or_tensor)
-
-
-def dynamic_abs(num_or_tensor, **kwargs):
-  if tensor_util.is_tensor(num_or_tensor):
-    return math_ops.abs(num_or_tensor, **kwargs)
-  else:
-    return abs(num_or_tensor, **kwargs)
-
-
-def dynamic_range(start_or_stop, stop=None, step=None):
-  """Implementation of range using dynamic dispatch."""
-  if type_check.is_tensor(start_or_stop, stop, step):
-    if step is not None:
-      return math_ops.range(start_or_stop, stop, step)
-    if stop is not None:
-      return math_ops.range(start_or_stop, stop)
-    return math_ops.range(start_or_stop)
-
-  if step is not None:
-    return range(start_or_stop, stop, step)
-  elif stop is not None:
-    return range(start_or_stop, stop)
-  return range(start_or_stop)
-
-
-def is_tf_print_compatible(value):
-  # TODO(mdan): Enable once we can reliably test this.
-  # This is currently disabled because we can't capture the output of
-  # op kernels from Python.
-  del value
-  return False
-
-
-def dynamic_print(*values):
-  """Implementation of print using dynamic dispatch.
-
-  The function attempts to use tf.Print if all the values are compatible.
-  Otherwise, it will fall back to py_func.
-
-  Args:
-    *values: values to print
-  Returns:
-    A dummy value indicating the print completed. If tf.
-  """
-
-  if all(map(is_tf_print_compatible, values)):
-    return logging_ops.Print(1, values)
-
-  def print_wrapper(*vals):
-    if six.PY3:
-      # TensorFlow doesn't seem to generate Unicode when passing strings to
-      # py_func. This causes the print to add a "b'" wrapper to the output,
-      # which is probably never what you want.
-      vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
-    print(*vals)
-    # The flush helps avoid garbled output in IPython.
-    sys.stdout.flush()
-
-  return py_func.wrap_py_func(
-      print_wrapper, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
deleted file mode 100644
index b1cd525..0000000
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for builtins module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
-
-class BuiltinsTest(test.TestCase):
-
-  def test_dynamic_len_tf_scalar(self):
-    a = constant_op.constant(1)
-
-    with self.assertRaisesRegexp(ValueError,
-                                 'len requires non-zero rank for tensor.*'):
-      with self.test_session() as sess:
-        sess.run(builtins.dynamic_builtin(len, a))
-
-  def test_dynamic_len_tf_array(self):
-    a = constant_op.constant([1, 2, 3])
-
-    with self.test_session() as sess:
-      self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
-
-  def test_dynamic_abs_tf_scalar(self):
-    a = constant_op.constant(-1)
-
-    with self.test_session() as sess:
-      self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
-
-  def test_dynamic_abs_tf_array(self):
-    a = constant_op.constant([-1, 2, -3])
-
-    with self.test_session() as sess:
-      self.assertListEqual([1, 2, 3],
-                           list(sess.run(builtins.dynamic_builtin(abs, a))))
-
-  def test_dynamic_abs_py_scalar(self):
-    a = -1
-    self.assertEqual(1, builtins.dynamic_builtin(abs, a))
-
-  def test_dynamic_len_tf_matrix(self):
-    a = constant_op.constant([[1, 2], [3, 4]])
-
-    with self.test_session() as sess:
-      self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
-
-  def test_dynamic_len_py_list(self):
-    a = [3] * 5
-
-    self.assertEqual(5, builtins.dynamic_builtin(len, a))
-
-  def test_dynamic_range_all_python(self):
-    self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2])
-    self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2])
-    self.assertListEqual(
-        list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1])
-
-  def test_dynamic_range_tf(self):
-    with self.test_session() as sess:
-      self.assertAllEqual(
-          sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))),
-          [0, 1, 2])
-      self.assertAllEqual(
-          sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))),
-          [1, 2])
-      self.assertAllEqual(
-          sess.run(
-              builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))),
-          [2, 1])
-
-  def test_dynamic_range_detection(self):
-    def range(x):  # pylint:disable=redefined-builtin
-      return x
-
-    # Functions that just have the names of builtins are rejected.
-    with self.assertRaises(NotImplementedError):
-      self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
-    if six.PY2:
-      self.assertListEqual(
-          list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
-    self.assertListEqual(
-        list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
-    self.assertListEqual(
-        list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
-
-  def test_casts(self):
-    i = constant_op.constant(2, dtype=dtypes.int32)
-    f = constant_op.constant(1.0, dtype=dtypes.float32)
-
-    self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
-    self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
-    self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
-    self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
-
-    self.assertEqual(builtins.dynamic_builtin(int, True), 1)
-    self.assertEqual(builtins.dynamic_builtin(int, False), 0)
-    self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
-    self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
-
-  def test_dynamic_print_tf(self):
-    try:
-      out_capturer = six.StringIO()
-      sys.stdout = out_capturer
-      with self.test_session() as sess:
-        sess.run(builtins.dynamic_print('test message', 1))
-        self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
-    finally:
-      sys.stdout = sys.__stdout__
-
-  def test_dynamic_print_complex(self):
-    try:
-      out_capturer = six.StringIO()
-      sys.stdout = out_capturer
-      with self.test_session() as sess:
-        sess.run(builtins.dynamic_print('test message', [1, 2]))
-        self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
-    finally:
-      sys.stdout = sys.__stdout__
-
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
deleted file mode 100644
index f72f8e9..0000000
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for multiple_dispatch."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.autograph.utils import multiple_dispatch
-from tensorflow.python.client.session import Session
-from tensorflow.python.framework.constant_op import constant
-from tensorflow.python.platform import test
-
-
-class MultipleDispatchTest(test.TestCase):
-
-  def test_dynamic_is_python(self):
-    a = np.eye(3)
-    also_a = a
-    not_actually_a = np.eye(3)
-    should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
-    should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
-    should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
-    should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
-    self.assertTrue(should_be_true1)
-    self.assertTrue(should_be_true2)
-    self.assertFalse(should_be_false1)
-    self.assertFalse(should_be_false2)
-
-  def test_dynamic_is_tf(self):
-    with Session().as_default():
-      a = constant([2.0])
-      also_a = a
-      not_actually_a = constant([2.0])
-      should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
-      should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
-      should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
-      should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
-      self.assertTrue(should_be_true1)
-      self.assertTrue(should_be_true2)
-      self.assertFalse(should_be_false1)
-      self.assertFalse(should_be_false2)
-
-  def test_run_cond_python(self):
-    true_fn = lambda: (2,)
-    false_fn = lambda: (3,)
-    self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2)
-    self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3)
-
-  def test_run_cond_tf(self):
-    true_fn = lambda: (constant(2),)
-    false_fn = lambda: (constant(3),)
-    with Session() as sess:
-      out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn)
-      self.assertEqual(sess.run(out), 2)
-      out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
-      self.assertEqual(sess.run(out), 3)
-
-
-if __name__ == '__main__':
-  test.main()
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index a25a641..6138d79 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -172,6 +172,11 @@
 REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
                         BigtableTableOp);
 
+}  // namespace
+
+namespace data {
+namespace {
+
 class ToBigtableOp : public AsyncOpKernel {
  public:
   explicit ToBigtableOp(OpKernelConstruction* ctx)
@@ -354,5 +359,6 @@
                         ToBigtableOp);
 
 }  // namespace
+}  // namespace data
 
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index a2a5df1..4652021 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -79,6 +79,8 @@
   ::google::cloud::bigtable::noex::Table table_;
 };
 
+namespace data {
+
 // BigtableReaderDatasetIterator is an abstract class for iterators from
 // datasets that are "readers" (source datasets, not transformation datasets)
 // that read from Bigtable.
@@ -138,6 +140,8 @@
   ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
 };
 
+}  // namespace data
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index bd32672..11f530e 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
@@ -226,4 +227,5 @@
                         BigtableLookupDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index a803fdc..5cab729 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
@@ -111,4 +112,5 @@
                         BigtablePrefixKeyDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 5cd0371..4dc4647 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
@@ -117,4 +118,5 @@
 REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
                         BigtableRangeKeyDatasetOp);
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 6928d94..736775b 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
@@ -205,4 +206,5 @@
     BigtableSampleKeyPairsDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index a759fb5..208b7b3 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
@@ -118,4 +119,5 @@
                         BigtableSampleKeysDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 78a920b..9407855 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class BigtableScanDatasetOp : public DatasetOpKernel {
@@ -224,4 +225,5 @@
                         BigtableScanDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce24..4c7a538 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@
                center_bias=True,
                use_core_libs=False,
                output_leaf_index=False,
-               override_global_step_value=None):
+               override_global_step_value=None,
+               num_quantiles=100):
     """Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
 
     Args:
@@ -94,6 +95,7 @@
         trees were trained), this parameter can be used to set the global step
         to a large value, making it look like that number of training steps ran.
         If None, no override of global step will happen.
+      num_quantiles: Number of quantiles to build for numeric feature values.
 
     Raises:
       ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@
             'logits_modifier_function': logits_modifier_function,
             'use_core_libs': use_core_libs,
             'output_leaf_index': output_leaf_index,
-            'override_global_step_value': override_global_step_value
+            'override_global_step_value': override_global_step_value,
+            'num_quantiles': num_quantiles,
         },
         model_dir=model_dir,
         config=config,
@@ -159,7 +162,8 @@
                center_bias=True,
                use_core_libs=False,
                output_leaf_index=False,
-               override_global_step_value=None):
+               override_global_step_value=None,
+               num_quantiles=100):
     """Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
 
     Args:
@@ -201,6 +205,7 @@
         trees were trained), this parameter can be used to set the global step
         to a large value, making it look like that number of training steps ran.
         If None, no override of global step will happen.
+      num_quantiles: Number of quantiles to build for numeric feature values.
     """
     head = head_lib.regression_head(
         label_name=label_name,
@@ -224,7 +229,8 @@
             'center_bias': center_bias,
             'use_core_libs': use_core_libs,
             'output_leaf_index': False,
-            'override_global_step_value': override_global_step_value
+            'override_global_step_value': override_global_step_value,
+            'num_quantiles': num_quantiles,
         },
         model_dir=model_dir,
         config=config,
@@ -251,7 +257,8 @@
                center_bias=True,
                use_core_libs=False,
                output_leaf_index=False,
-               override_global_step_value=None):
+               override_global_step_value=None,
+               num_quantiles=100):
     """Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
 
     Args:
@@ -289,6 +296,7 @@
         trees were trained), this parameter can be used to set the global step
         to a large value, making it look like that number of training steps ran.
         If None, no override of global step will happen.
+      num_quantiles: Number of quantiles to build for numeric feature values.
     """
     super(GradientBoostedDecisionTreeEstimator, self).__init__(
         model_fn=model.model_builder,
@@ -303,7 +311,8 @@
             'center_bias': center_bias,
             'use_core_libs': use_core_libs,
             'output_leaf_index': False,
-            'override_global_step_value': override_global_step_value
+            'override_global_step_value': override_global_step_value,
+            'num_quantiles': num_quantiles,
         },
         model_dir=model_dir,
         config=config,
@@ -329,7 +338,8 @@
                center_bias=False,
                use_core_libs=False,
                output_leaf_index=False,
-               override_global_step_value=None):
+               override_global_step_value=None,
+               num_quantiles=100):
     """Initializes a GradientBoostedDecisionTreeRanker instance.
 
     This is an estimator that can be trained off the pairwise data and can be
@@ -377,6 +387,8 @@
         trees were trained), this parameter can be used to set the global step
         to a large value, making it look like that number of training steps ran.
         If None, no override of global step will happen.
+      num_quantiles: Number of quantiles to build for numeric feature values.
+
     Raises:
       ValueError: If learner_config is not valid.
     """
@@ -395,7 +407,8 @@
             'use_core_libs': use_core_libs,
             'output_leaf_index': output_leaf_index,
             'ranking_model_pair_keys': ranking_model_pair_keys,
-            'override_global_step_value': override_global_step_value
+            'override_global_step_value': override_global_step_value,
+            'num_quantiles': num_quantiles,
         },
         model_dir=model_dir,
         config=config,
@@ -444,7 +457,8 @@
                feature_engineering_fn=None,
                logits_modifier_function=None,
                center_bias=True,
-               output_leaf_index=False):
+               output_leaf_index=False,
+               num_quantiles=100):
     """Initializes a core version of GradientBoostedDecisionTreeEstimator.
 
     Args:
@@ -474,6 +488,7 @@
         for example_prediction_result in result_dict:
           # access leaf index list by example_prediction_result["leaf_index"]
           # which contains one leaf index per tree
+      num_quantiles: Number of quantiles to build for numeric feature values.
     """
 
     def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@
               'logits_modifier_function': logits_modifier_function,
               'use_core_libs': True,
               'output_leaf_index': output_leaf_index,
-              'override_global_step_value': None
+              'override_global_step_value': None,
+              'num_quantiles': num_quantiles,
           },
           output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
 
@@ -517,7 +533,8 @@
                label_keys=None,
                logits_modifier_function=None,
                center_bias=False,
-               output_leaf_index=False):
+               output_leaf_index=False,
+               num_quantiles=100):
     """Initializes a GradientBoostedDecisionTreeRanker instance.
 
     This is an estimator that can be trained off the pairwise data and can be
@@ -552,6 +569,7 @@
         for result_dict in result_iter:
           # access leaf index list by result_dict["leaf_index"]
           # which contains one leaf index per tree
+      num_quantiles: Number of quantiles to build for numeric feature values.
 
     Raises:
       ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@
               'use_core_libs': True,
               'output_leaf_index': output_leaf_index,
               'ranking_model_pair_keys': ranking_model_pair_keys,
-              'override_global_step_value': None
+              'override_global_step_value': None,
+              'num_quantiles': num_quantiles,
           },
           output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
 
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3..a6e4228 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@
   logits_modifier_function = params["logits_modifier_function"]
   output_leaf_index = params["output_leaf_index"]
   override_global_step_value = params.get("override_global_step_value", None)
+  num_quantiles = params["num_quantiles"]
 
   if features is None:
     raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@
       logits_dimension=head.logits_dimension,
       features=training_features,
       use_core_columns=use_core_libs,
-      output_leaf_index=output_leaf_index)
+      output_leaf_index=output_leaf_index,
+      num_quantiles=num_quantiles)
   with ops.name_scope("gbdt", "gbdt_optimizer"):
     predictions_dict = gbdt_model.predict(mode)
     logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@
   output_leaf_index = params["output_leaf_index"]
   ranking_model_pair_keys = params["ranking_model_pair_keys"]
   override_global_step_value = params.get("override_global_step_value", None)
+  num_quantiles = params["num_quantiles"]
 
   if features is None:
     raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@
       logits_dimension=head.logits_dimension,
       features=main_features,
       use_core_columns=use_core_libs,
-      output_leaf_index=output_leaf_index)
+      output_leaf_index=output_leaf_index,
+      num_quantiles=num_quantiles)
 
   with ops.name_scope("gbdt", "gbdt_optimizer"):
     # Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 1375fdd..606da66 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -296,8 +296,9 @@
             int64 start, int64 end) {
           for (int resource_handle_idx = start; resource_handle_idx < end;
                ++resource_handle_idx) {
-            ResourceHandle handle = resource_handle_list[resource_handle_idx]
-                                        .flat<ResourceHandle>()(0);
+            const ResourceHandle& handle =
+                resource_handle_list[resource_handle_idx]
+                    .flat<ResourceHandle>()(0);
             QuantileStreamResource* streams_resource;
             // Create a reference to the underlying resource using the handle.
             OP_REQUIRES_OK(context,
@@ -709,8 +710,9 @@
          &buckets_list, stamp_token](int64 start, int64 end) {
           for (int resource_handle_idx = start; resource_handle_idx < end;
                ++resource_handle_idx) {
-            ResourceHandle handle = resource_handle_list[resource_handle_idx]
-                                        .flat<ResourceHandle>()(0);
+            const ResourceHandle& handle =
+                resource_handle_list[resource_handle_idx]
+                    .flat<ResourceHandle>()(0);
             QuantileStreamResource* streams_resource;
             OP_REQUIRES_OK(context,
                            LookupResource(context, handle, &streams_resource));
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 3b28ed7..51e0c2e 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -862,6 +862,15 @@
       auto* equality_split = split_info.mutable_split_node()
                                  ->mutable_categorical_id_binary_split();
       equality_split->set_feature_column(state->feature_column_group_id());
+      CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id)
+          << "Unexpected feature ID selected. "
+          << "Start feature ID: [" << start_index << "] "
+          << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1)
+          << "\nBest feature ID: [" << best_feature_idx << "] "
+          << feature_ids(best_feature_idx, 0) << ", "
+          << feature_ids(best_feature_idx, 1)
+          << "\nPartition IDS: " << partition_ids(start_index) << "  "
+          << partition_ids(best_feature_idx);
       equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
index 90a0655..e446c41 100644
--- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
@@ -448,8 +448,9 @@
          stamp_token](int64 start, int64 end) {
           for (int resource_handle_idx = start; resource_handle_idx < end;
                ++resource_handle_idx) {
-            ResourceHandle handle = resource_handle_list[resource_handle_idx]
-                                        .flat<ResourceHandle>()(0);
+            const ResourceHandle& handle =
+                resource_handle_list[resource_handle_idx]
+                    .flat<ResourceHandle>()(0);
 
             StatsAccumulatorScalarResource* accumulator_resource;
             OP_REQUIRES_OK(context, LookupResource(context, handle,
@@ -512,8 +513,9 @@
          stamp_token](int64 start, int64 end) {
           for (int resource_handle_idx = start; resource_handle_idx < end;
                ++resource_handle_idx) {
-            ResourceHandle handle = resource_handle_list[resource_handle_idx]
-                                        .flat<ResourceHandle>()(0);
+            const ResourceHandle& handle =
+                resource_handle_list[resource_handle_idx]
+                    .flat<ResourceHandle>()(0);
 
             StatsAccumulatorTensorResource* accumulator_resource;
             OP_REQUIRES_OK(context, LookupResource(context, handle,
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index e640717..4da2529 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -29,7 +29,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 
-_BIAS_FEATURE_ID = -1
+_BIAS_FEATURE_ID = int(dtypes.int64.min)
 
 
 class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
@@ -141,11 +141,18 @@
       # The bias is computed on gradients and hessians (and not
       # filtered_gradients) which have exactly one value per example, so we
       # don't double count a gradient in multivalent columns.
+      # Since unsorted_segment_sum can be numerically unstable, use 64bit
+      # operation.
+      gradients64 = math_ops.cast(gradients, dtypes.float64)
+      hessians64 = math_ops.cast(hessians, dtypes.float64)
       per_partition_gradients = math_ops.unsorted_segment_sum(
-          gradients, mapped_partitions, array_ops.size(unique_partitions))
+          gradients64, mapped_partitions, array_ops.size(unique_partitions))
       per_partition_hessians = math_ops.unsorted_segment_sum(
-          hessians, mapped_partitions, array_ops.size(unique_partitions))
-
+          hessians64, mapped_partitions, array_ops.size(unique_partitions))
+      per_partition_gradients = math_ops.cast(per_partition_gradients,
+                                              dtypes.float32)
+      per_partition_hessians = math_ops.cast(per_partition_hessians,
+                                             dtypes.float32)
       # Prepend a bias feature per partition that accumulates the stats for all
       # examples in that partition.
       # Bias is added to the stats even if there are no examples with values in
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index d9f03c3..94ea7bc 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -47,7 +47,7 @@
 class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
 
   def testGenerateFeatureSplitCandidates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Feature ID     |
       # i0      |  (0.2, 0.12)  | 0         | 1,2            |
@@ -281,7 +281,7 @@
         gains[0], 0.00001)
 
   def testGenerateFeatureSplitCandidatesSumReduction(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Feature ID     |
       # i0      |  (0.2, 0.12)  | 0         | 1,2            |
@@ -404,7 +404,7 @@
     self.assertEqual(1, split_node.feature_id)
 
   def testGenerateFeatureSplitCandidatesMulticlass(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Batch size is 4, 2 gradients per each instance.
       gradients = array_ops.constant(
           [[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2])
@@ -482,7 +482,7 @@
     self.assertEqual(1, split_node.feature_id)
 
   def testEmpty(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
       hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
       partition_ids = [0, 0, 0, 1]
@@ -530,7 +530,7 @@
     self.assertEqual(len(splits), 0)
 
   def testInactive(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
       hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
       partition_ids = [0, 0, 0, 1]
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 5532bd0..74b0ea6 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -50,7 +50,7 @@
 class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
   def testGenerateFeatureSplitCandidates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1              |
@@ -183,7 +183,7 @@
     self.assertAllClose(0.52, split_node.threshold, 0.00001)
 
   def testObliviousFeatureSplitGeneration(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 1         | 3              |
@@ -320,7 +320,7 @@
     self.assertEqual(2, oblivious_split_info.children_parent_id[1])
 
   def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1              |
@@ -458,7 +458,7 @@
     self.assertAllClose(0.52, split_node.threshold, 0.00001)
 
   def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
       # Batch size is 4, 2 gradients per each instance.
       gradients = array_ops.constant(
@@ -546,7 +546,7 @@
     self.assertAllClose(0.3, split_node.threshold, 1e-6)
 
   def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
       # Batch size is 4, 2 gradients per each instance.
       gradients = array_ops.constant(
@@ -633,7 +633,7 @@
     self.assertAllClose(0.3, split_node.threshold, 1e-6)
 
   def testGenerateFeatureSplitCandidatesInactive(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1              |
@@ -708,7 +708,7 @@
     self.assertEqual(len(splits), 0)
 
   def testGenerateFeatureSplitCandidatesWithTreeComplexity(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1              |
@@ -842,7 +842,7 @@
     self.assertAllClose(0.52, split_node.threshold, 0.00001)
 
   def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Dense Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1              |
@@ -951,7 +951,7 @@
 class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
 
   def testGenerateFeatureSplitCandidates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Sparse Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1               |
@@ -1074,7 +1074,7 @@
     self.assertAllClose(0.52, split_node.split.threshold)
 
   def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Sparse Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1               |
@@ -1207,7 +1207,7 @@
     self.assertAllClose(0.52, split_node.split.threshold)
 
   def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Batch is 4, 2 classes
       gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
                                       [4.0, -3]])
@@ -1302,7 +1302,7 @@
     self.assertAllClose(0.52, split_node.split.threshold)
 
   def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Batch is 4, 2 classes
       gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
                                       [4.0, -3]])
@@ -1397,7 +1397,7 @@
     self.assertAllClose(0.52, split_node.split.threshold)
 
   def testGenerateFeatureSplitCandidatesInactive(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The data looks like the following:
       # Example |  Gradients    | Partition | Sparse Quantile |
       # i0      |  (0.2, 0.12)  | 0         | 1               |
@@ -1475,7 +1475,7 @@
     self.assertEqual(len(splits), 0)
 
   def testEmpty(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
       # No values in this feature column in this mini-batch.
       values = array_ops.constant([], dtype=dtypes.float32)
@@ -1545,7 +1545,7 @@
 
   def testEmptyBuckets(self):
     """Test that reproduces the case when quantile buckets were empty."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_column = array_ops.sparse_placeholder(dtypes.float32)
 
       # We have two batches - at first, a sparse feature is empty.
@@ -1638,7 +1638,7 @@
     self.assertEqual(len(splits), 0)
 
   def testDegenerativeCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # One data example only, one leaf and thus one quantile bucket.The same
       # situation is when all examples have the same values. This case was
       # causing before a failure.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index 4278a30..46dfbde 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -331,7 +331,7 @@
       self.assertAllEqual([[], []], dropout_info.eval())
 
   def testObliviousEnsemble(self):
-    with self.test_session():
+    with self.cached_session():
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       # Bias tree.
       tree1 = tree_ensemble_config.trees.add()
@@ -1399,7 +1399,7 @@
       self.assertAllEqual([0, 0], result.eval())
 
   def testObliviousTreeNonFinalized(self):
-    with self.test_session():
+    with self.cached_session():
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       # Depth 3 tree.
       tree1 = tree_ensemble_config.trees.add()
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index b3e4c2e..86fd577 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -411,7 +411,7 @@
 
   def testGrowEmptyEnsembleObliviousCase(self):
     """Test growing an empty ensemble in the oblivious case."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -1620,7 +1620,7 @@
 
   def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
     """Test growing an existing ensemble with the last tree not finalized."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create existing ensemble with one root split
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
@@ -1810,7 +1810,7 @@
 
   def testGrowEnsembleWithEmptyNodesMiddleCase(self):
     """Test case: The middle existing leaves don't have examples."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
           """
@@ -2071,7 +2071,7 @@
 
   def testGrowEnsembleWithEmptyNodesBorderCase(self):
     """Test case: The first and last existing leaves don't have examples."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
       text_format.Merge(
           """
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b008c6e..c7eb249 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@
                feature_columns=None,
                use_core_columns=False,
                output_leaf_index=False,
-               output_leaf_index_modes=None):
+               output_leaf_index_modes=None,
+               num_quantiles=100):
     """Construct a new GradientBoostedDecisionTreeModel function.
 
     Args:
@@ -327,6 +328,7 @@
       output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
         dictates when leaf indices will be outputted. By default, leaf indices
         are only outputted in INFER mode.
+      num_quantiles: Number of quantiles to build for numeric feature values.
 
     Raises:
       ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@
     self._learner_config = learner_config
     self._feature_columns = feature_columns
     self._learner_config_serialized = learner_config.SerializeToString()
+    self._num_quantiles = num_quantiles
     self._max_tree_depth = variables.Variable(
         initial_value=self._learner_config.constraints.max_tree_depth)
     self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@
     loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
     weak_learner_type = constant_op.constant(
         self._learner_config.weak_learner_type)
-    epsilon = 0.01
-    num_quantiles = 100
+    num_quantiles = self._num_quantiles
+    epsilon = 1.0 / num_quantiles
     strategy_tensor = constant_op.constant(strategy)
     with ops.device(self._get_replica_device_setter(worker_device)):
       # Create handlers for dense float columns
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 150d734..94b7f4f 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -37,6 +37,7 @@
 
 Saving and restoring Python state:
 @@NumpyState
+@@PythonStateWrapper
 """
 
 from __future__ import absolute_import
@@ -45,6 +46,7 @@
 
 from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
 from tensorflow.contrib.checkpoint.python.python_state import NumpyState
+from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
 from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
 from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
 from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035..302d5cf 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import functools
+import six
 
 import numpy
 
@@ -101,7 +103,7 @@
     # TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
     # ndarrays checkpointable natively and using standard checkpointable list
     # tracking.
-    if isinstance(value, numpy.ndarray):
+    if isinstance(value, (numpy.ndarray, numpy.generic)):
       try:
         existing = super(NumpyState, self).__getattribute__(name)
         existing.array = value
@@ -127,7 +129,29 @@
     super(NumpyState, self).__setattr__(name, value)
 
 
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+  """Wraps a Python object for storage in an object-based checkpoint."""
+
+  @abc.abstractmethod
+  def _serialize(self):
+    """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+  @abc.abstractmethod
+  def _deserialize(self, string_value):
+    """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+  def _gather_saveables_for_checkpoint(self):
+    """Specify callbacks for saving and restoring `array`."""
+    return {
+        "py_state": functools.partial(
+            base.PythonStringStateSaveable,
+            state_callback=self._serialize,
+            restore_callback=self._deserialize)
+        }
+
+
+class _NumpyWrapper(PythonStateWrapper):
   """Wraps a NumPy array for storage in an object-based checkpoint."""
 
   def __init__(self, array):
@@ -139,7 +163,7 @@
     self.array = array
 
   def _serialize(self):
-    """Callback for `PythonStringStateSaveable` to serialize the array."""
+    """Callback to serialize the array."""
     string_file = BytesIO()
     try:
       numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@
     return serialized
 
   def _deserialize(self, string_value):
-    """Callback for `PythonStringStateSaveable` to deserialize the array."""
+    """Callback to deserialize the array."""
     string_file = BytesIO(string_value)
     try:
       self.array = numpy.load(string_file, allow_pickle=False)
     finally:
       string_file.close()
 
-  def _gather_saveables_for_checkpoint(self):
-    """Specify callbacks for saving and restoring `array`."""
-    return {
-        "array": functools.partial(
-            base.PythonStringStateSaveable,
-            state_callback=self._serialize,
-            restore_callback=self._deserialize)
-        }
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
index 0439a47..4549435 100644
--- a/tensorflow/contrib/checkpoint/python/python_state_test.py
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -40,10 +40,13 @@
     save_state.a = numpy.ones([2, 2])
     save_state.b = numpy.ones([2, 2])
     save_state.b = numpy.zeros([2, 2])
+    save_state.c = numpy.int64(3)
     self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
     self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+    self.assertEqual(3, save_state.c)
     first_save_path = saver.save(prefix)
     save_state.a[1, 1] = 2.
+    save_state.c = numpy.int64(4)
     second_save_path = saver.save(prefix)
 
     load_state = python_state.NumpyState()
@@ -51,6 +54,7 @@
     loader.restore(first_save_path).initialize_or_restore()
     self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
     self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+    self.assertEqual(3, load_state.c)
     load_state.a[0, 0] = 42.
     self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
     loader.restore(first_save_path).run_restore_ops()
@@ -58,6 +62,7 @@
     loader.restore(second_save_path).run_restore_ops()
     self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
     self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+    self.assertEqual(4, load_state.c)
 
   def testNoGraphPollution(self):
     graph = ops.Graph()
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
index 493b3c6..11e177c 100644
--- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
@@ -197,7 +197,7 @@
   def _ReadAndCheckRowsUsingFeatures(self, num_rows):
     self.server.handler.num_rows = num_rows
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       feature_configs = {
           "int64_col":
               parsing_ops.FixedLenFeature(
@@ -254,7 +254,7 @@
     num_rows = 10
     self.server.handler.num_rows = num_rows
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = cloud.BigQueryReader(
           project_id=_PROJECT,
           dataset_id=_DATASET,
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
index 9b6c056..4f2ecbc 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -26,7 +26,7 @@
 
   def testSetBlockCache(self):
     cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       gcs_config_ops.configure_gcs(sess, block_cache=cfg)
 
   def testConfigureGcsHook(self):
@@ -36,7 +36,7 @@
              'type': 'authorized_user'}
     hook = gcs_config_ops.ConfigureGcsHook(credentials=creds)
     hook.begin()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None
       hook.after_create_session(sess, None)
 
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d..1056894f 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@
   def get_master(self):
     return self.master()
 
+  def get_job_name(self):
+    if self._shouldResolve():
+      return self._job_name
+
   def cluster_spec(self):
     """Returns a ClusterSpec object based on the latest TPU information.
 
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 0b79f71..789dab8 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -1,6 +1,10 @@
 TensorFlow CMake build
 ======================
 
+The CMake build is deprecated for TensorFlow. Please use `bazel` to build TF
+on all platforms. For details, see the
+[TensorFlow install guide](https://www.tensorflow.org/install/).
+
 This directory contains CMake files for building TensorFlow on Microsoft
 Windows. [CMake](https://cmake.org) is a cross-platform tool that can
 generate build scripts for multiple build systems, including Microsoft
diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake
index ad2af01..1a147e9 100644
--- a/tensorflow/contrib/cmake/external/png.cmake
+++ b/tensorflow/contrib/cmake/external/png.cmake
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 include (ExternalProject)
+include (GNUInstallDirs)
 
 set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive)
 set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz)
@@ -35,7 +36,7 @@
     endif()
   endif()
 else()
-  set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a)
+  set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a)
 endif()
 
 set(png_HEADERS
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
index 9b4bf62..3e25079 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -75,7 +75,7 @@
     multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
     expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       projected_multipliers1 = session.run(
           external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
               multipliers1, 1.0))
@@ -122,7 +122,7 @@
     ]
 
     multipliers = []
-    with self.test_session() as session:
+    with self.cached_session() as session:
       session.run(standard_ops.global_variables_initializer())
       while len(multipliers) < len(expected_multipliers):
         multipliers.append(session.run(optimizer.lagrange_multipliers))
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
index 34c4543..df0eced 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -97,7 +97,7 @@
     matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
     matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       eigenvector1 = session.run(
           swap_regret_optimizer._maximal_eigenvector_power_method(
               standard_ops.constant(matrix1)))
@@ -119,7 +119,7 @@
     expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
                                           [0.4, 0.3, 0.0]])
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       projected_matrix = session.run(
           swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
               matrix))
@@ -134,7 +134,7 @@
     expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
                                           [0.4, 0.5, 0.3]])
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       projected_matrix = session.run(
           standard_ops.exp(
               swap_regret_optimizer.
@@ -165,7 +165,7 @@
     ]
 
     matrices = []
-    with self.test_session() as session:
+    with self.cached_session() as session:
       session.run(standard_ops.global_variables_initializer())
       while len(matrices) < len(expected_matrices):
         matrices.append(session.run(optimizer.stochastic_matrix))
@@ -198,7 +198,7 @@
     ]
 
     matrices = []
-    with self.test_session() as session:
+    with self.cached_session() as session:
       session.run(standard_ops.global_variables_initializer())
       while len(matrices) < len(expected_matrices):
         matrices.append(session.run(optimizer.stochastic_matrix))
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 8cfe142..556d731 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -61,7 +61,7 @@
     for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
                                                      inputs_list,
                                                      tag_indices_list):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sequence_score = crf.crf_sequence_score(
             inputs=array_ops.expand_dims(inputs, 0),
             tag_indices=array_ops.expand_dims(tag_indices, 0),
@@ -96,7 +96,7 @@
     ]
     for sequence_lengths, inputs, tag_bitmap in zip(
         sequence_lengths_list, inputs_list, tag_bitmap_list):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sequence_score = crf.crf_multitag_sequence_score(
             inputs=array_ops.expand_dims(inputs, 0),
             tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
@@ -124,7 +124,7 @@
     for dtype in (np.int32, np.int64):
       tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
       sequence_lengths = np.array(3, dtype=np.int32)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         unary_score = crf.crf_unary_score(
             tag_indices=array_ops.expand_dims(tag_indices, 0),
             sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -140,7 +140,7 @@
     transition_params = np.array(
         [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
     sequence_lengths = np.array(3, dtype=np.int32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       binary_score = crf.crf_binary_score(
           tag_indices=array_ops.expand_dims(tag_indices, 0),
           sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -176,7 +176,7 @@
                                                      tag_indices_list):
       num_words = inputs.shape[0]
       num_tags = inputs.shape[1]
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         all_sequence_scores = []
 
         # Compare the dynamic program with brute force computation.
@@ -206,7 +206,7 @@
     """
     Test `crf_log_norm` when `sequence_lengths` contains one or more zeros.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = constant_op.constant(np.ones([2, 10, 5],
                                             dtype=np.float32))
       transition_params = constant_op.constant(np.ones([5, 5],
@@ -226,7 +226,7 @@
     sequence_lengths = np.array(3, dtype=np.int32)
     num_words = inputs.shape[0]
     num_tags = inputs.shape[1]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       all_sequence_log_likelihoods = []
 
       # Make sure all probabilities sum to 1.
@@ -254,7 +254,7 @@
     num_words = inputs.shape[0]
     num_tags = inputs.shape[1]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       all_sequence_scores = []
       all_sequences = []
 
@@ -310,7 +310,7 @@
       num_words = inputs.shape[0]
       num_tags = inputs.shape[1]
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         all_sequence_scores = []
         all_sequences = []
 
@@ -351,7 +351,7 @@
     """
     Test that crf_decode works when sequence_length contains one or more zeros.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = constant_op.constant(np.ones([2, 10, 5],
                                             dtype=np.float32))
       transition_params = constant_op.constant(np.ones([5, 5],
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 5e6c152..baec238 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -26,6 +26,7 @@
 @@CheckpointInputPipelineHook
 @@CsvDataset
 @@LMDBDataset
+@@Optional
 @@RandomDataset
 @@Reducer
 @@SqlDataset
@@ -38,7 +39,7 @@
 @@copy_to_device
 @@dense_to_sparse_batch
 @@enumerate_dataset
-
+@@get_next_as_optional
 @@get_single_element
 @@group_by_reducer
 @@group_by_window
@@ -46,7 +47,6 @@
 @@make_batched_features_dataset
 @@make_csv_dataset
 @@make_saveable_from_iterator
-
 @@map_and_batch
 @@padded_batch_and_drop_remainder
 @@parallel_interleave
@@ -107,6 +107,8 @@
 from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
 from tensorflow.contrib.data.python.ops.unique import unique
 from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
 # pylint: enable=unused-import
 
 from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
index e36c9c06..c19a609 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -150,4 +151,5 @@
                         AssertNextDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 0ba905b..21ec50f 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/core/lib/io/zlib_inputstream.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class CSVDatasetOp : public DatasetOpKernel {
@@ -48,6 +49,9 @@
     OP_REQUIRES_OK(ctx,
                    ctx->input_list("record_defaults", &record_defaults_list));
     for (int i = 0; i < record_defaults_list.size(); ++i) {
+      OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+                  errors::InvalidArgument(
+                      "Each record default should be at most rank 1"));
       OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
                   errors::InvalidArgument(
                       "There should only be 1 default per field but field ", i,
@@ -851,4 +855,5 @@
 REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
index ccf7ec1..a532162 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/lib/hash/hash.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -276,5 +276,5 @@
                         DirectedInterleaveDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
index 4718c1c..c3cb45d 100644
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
@@ -150,4 +151,5 @@
                         IdentityIndexedDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
index db24e60..beec344 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -137,5 +137,5 @@
                         IgnoreErrorsDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc
index c69564a..ced8ab0 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/core/lib/gtl/cleanup.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 Status VerifyTypesMatch(const DataTypeVector& expected,
@@ -367,6 +367,7 @@
                         MaterializeDatasetOp);
 REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
                         IndexedDatasetGet);
-}  // namespace
 
+}  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h
index 6149de8..7aa2d3f 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.h
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 
 namespace tensorflow {
+namespace data {
 
 // TODO(saeta): Urgh, this is ugly.
 class MaterializedIndexedDataset {
@@ -112,6 +113,7 @@
 Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
                                           Tensor* tensor);
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
index 80f3999..d233c1f 100644
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
@@ -22,6 +22,7 @@
 #include "lmdb.h"  // NOLINT(build/include)
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class LMDBDatasetOp : public DatasetOpKernel {
@@ -212,4 +213,5 @@
 REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 725f893..078de71 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/core/util/device_name_utils.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 struct BufferElement {
@@ -1114,5 +1115,6 @@
     Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
     MultiDeviceIteratorFromStringHandleOp);
 
-}  // anonymous namespace
+}  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index ab58450..30fa97a 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class ThreadPoolResource : public ResourceBase {
@@ -214,4 +215,5 @@
                         ThreadPoolDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
index 6fbf5d2..57fc569 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/lib/hash/hash.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -219,5 +219,5 @@
                         UniqueDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ae104d5..ad410e1 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -65,7 +65,13 @@
       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
       // `record_defaults` must be lists of scalars
       for (size_t i = 8; i < c->num_inputs(); ++i) {
-        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+        shape_inference::ShapeHandle v;
+        TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+        if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+          return errors::InvalidArgument(
+              "Shape of a default must be a length-0 or length-1 vector, or a "
+              "scalar.");
+        }
       }
       return shape_inference::ScalarShape(c);
     });
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 34f594f..ba20283 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -72,12 +72,13 @@
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
         "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
     ],
 )
@@ -276,25 +277,13 @@
         "//tensorflow/python:check_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:data_flow_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:function",
+        "//tensorflow/python:functional_ops",
         "//tensorflow/python:math_ops",
-    ],
-)
-
-py_test(
-    name = "optimize_dataset_op_test",
-    size = "small",
-    srcs = ["optimize_dataset_op_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        "//tensorflow/contrib/data/python/ops:optimization",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:errors",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//third_party/py/numpy",
-        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/python:session",
     ],
 )
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 9d8e955..8e368bf 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
 
       for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
 
       for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize with an input tensor of incompatible rank.
       sess.run(init_op, feed_dict={input_tensor: [[1]]})
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@
     iterator = data.make_one_shot_iterator()
     op = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual((i,) * 3, sess.run(op))
 
@@ -168,7 +168,7 @@
     iterator = data.make_one_shot_iterator()
     op = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
 
@@ -187,7 +187,7 @@
     iterator = data.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         st_row = sess.run(next_element)
         self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@
     iterator = data.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         dense_elem, st_row = sess.run(next_element)
         self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@
     iterator = data.make_one_shot_iterator()
     op = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual(((i,),) * 3, sess.run(op))
 
@@ -250,7 +250,7 @@
     iterator = data.make_one_shot_iterator()
     op = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                          sess.run(op))
@@ -266,7 +266,7 @@
     iterator = data.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)
 
@@ -284,7 +284,7 @@
     iterator = data.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Mismatch in the 0th dimension.
       sess.run(
           iterator.initializer,
@@ -319,7 +319,7 @@
 
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for test_batch_size in [1, 3, 7, 10]:
         sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
         num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(2):
         actual = sess.run(get_next)
@@ -374,7 +374,7 @@
 
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for test_batch_size in [1, 3, 7, 10]:
         sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
         num_batches = 7 // test_batch_size
@@ -428,10 +428,10 @@
     self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
 
   @parameterized.named_parameters(
-      ("default", None, None),
-      ("sequential_calls", 1, None),
-      ("parallel_calls", 2, None),
-      ("parallel_batches", None, 10),
+      ("Default", None, None),
+      ("SequentialCalls", 1, None),
+      ("ParallelCalls", 2, None),
+      ("ParallelBatches", None, 10),
   )
   def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
     """Test a dataset that maps a TF function across its input elements."""
@@ -461,7 +461,7 @@
     self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                      [t.shape.as_list() for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Batch of a finite input, where the batch_size divides the
       # total number of elements.
       sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -505,8 +505,8 @@
         sess.run(init_op, feed_dict={count: 14, batch_size: 0})
 
   @parameterized.named_parameters(
-      ("even", False),
-      ("uneven", True),
+      ("Even", False),
+      ("Uneven", True),
   )
   def testMapAndBatchPartialBatch(self, drop_remainder):
     iterator = (
@@ -520,7 +520,7 @@
     else:
       self.assertEqual([None, 1], iterator.output_shapes.as_list())
     next_element = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
       self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
       if not drop_remainder:
@@ -535,7 +535,7 @@
                 .make_one_shot_iterator())
     self.assertEqual([None, 1], iterator.output_shapes.as_list())
     next_element = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
       self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
       self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@
     elements = []
     for _ in range(100):
       elements.append(iterator.get_next())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(5):
         got = sess.run(elements)
         got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@
     elements = []
     for _ in range(100):
       elements.append(iterator.get_next())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(4):
         got = sess.run(elements)
         got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(2):
         actual = sess.run(get_next)
@@ -614,7 +614,7 @@
         dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
         .make_initializable_iterator())
     init_op = iterator.initializer
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
         sess.run(init_op, feed_dict={batch_size: 14})
 
@@ -635,7 +635,7 @@
         .make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    "number of elements does not match"):
@@ -659,11 +659,18 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(3):
         sess.run(get_next)
 
-  @parameterized.parameters(0, 5, 10, 90, 95, 99)
+  @parameterized.named_parameters(
+      ("1", 0),
+      ("2", 5),
+      ("3", 10),
+      ("4", 90),
+      ("5", 95),
+      ("6", 99),
+  )
   def testMapAndBatchOutOfRangeError(self, threshold):
 
     def raising_py_fn(i):
@@ -679,7 +686,7 @@
                 batch_size=10)).make_one_shot_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(threshold // 10):
         self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
       if threshold % 10 != 0:
@@ -689,18 +696,18 @@
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @parameterized.parameters(
-      (False, dtypes.bool),
-      (-42, dtypes.int8),
-      (-42, dtypes.int16),
-      (-42, dtypes.int32),
-      (-42, dtypes.int64),
-      (42, dtypes.uint8),
-      (42, dtypes.uint16),
-      (42.0, dtypes.float16),
-      (42.0, dtypes.float32),
-      (42.0, dtypes.float64),
-      (b"hello", dtypes.string),
+  @parameterized.named_parameters(
+      ("1", False, dtypes.bool),
+      ("2", -42, dtypes.int8),
+      ("3", -42, dtypes.int16),
+      ("4", -42, dtypes.int32),
+      ("5", -42, dtypes.int64),
+      ("6", 42, dtypes.uint8),
+      ("7", 42, dtypes.uint16),
+      ("8", 42.0, dtypes.float16),
+      ("9", 42.0, dtypes.float32),
+      ("10", 42.0, dtypes.float64),
+      ("11", b"hello", dtypes.string),
   )
   def testMapAndBatchTypes(self, element, dtype):
     def gen():
@@ -711,7 +718,7 @@
 
     get_next = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(10):
         self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
 
@@ -777,7 +784,7 @@
     iterator = result.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(5):
         sess.run(get_next)
@@ -901,7 +908,7 @@
         .make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f..48971f2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@
   def checkResults(self, dataset, shapes, values):
     self.assertEqual(shapes, dataset.output_shapes)
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for expected in values:
         got = sess.run(get_next)
         self.assertEqual(got, expected)
@@ -129,7 +129,7 @@
       self.assertIs(None, dataset.output_shapes[1].ndims)
       iterator = dataset.make_one_shot_iterator()
       get_next = iterator.get_next()
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x, y = sess.run(get_next)
         self.assertAllEqual([0] * (2**i), x)
         self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@
         (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
             grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x, y = sess.run(get_next)
       self.assertAllEqual(x, np.asarray([x for x in range(10)]))
       self.assertEqual(y, 45)
@@ -210,7 +210,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       counts = []
       with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       # The input is infinite, so this test demonstrates that:
       # 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
       self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -301,7 +301,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
@@ -329,7 +329,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       counts = []
       with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
 
       which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
 
       # Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
 
       # Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.OutOfRangeError):
         batches = 0
@@ -531,6 +531,45 @@
       self.assertEqual(batches, 15)
 
 
+def _element_length_fn(x, y=None):
+  del y
+  return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+  return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+  if sparse:
+    return {
+        "values": array,
+        "indices": [[i] for i in range(len(array))],
+        "dense_shape": (len(array),)
+    }
+  return array
+
+
+def _get_record_type(sparse):
+  if sparse:
+    return {
+        "values": dtypes.int64,
+        "indices": dtypes.int64,
+        "dense_shape": dtypes.int64
+    }
+  return dtypes.int32
+
+
+def _get_record_shape(sparse):
+  if sparse:
+    return {
+        "values": tensor_shape.TensorShape([None,]),
+        "indices": tensor_shape.TensorShape([None, 1]),
+        "dense_shape": tensor_shape.TensorShape([1,])
+    }
+  return tensor_shape.TensorShape([None])
+
+
 class BucketBySequenceLength(test.TestCase):
 
   def testBucket(self):
@@ -539,39 +578,58 @@
     batch_sizes = [10, 8, 4, 2]
     lengths = [8, 13, 25, 35]
 
-    def element_gen():
-      # Produce 1 batch for each bucket
-      elements = []
-      for batch_size, length in zip(batch_sizes, lengths):
-        for _ in range(batch_size):
-          elements.append([1] * length)
-      random.shuffle(elements)
-      for el in elements:
-        yield (el,)
+    def build_dataset(sparse):
+      def _generator():
+        # Produce 1 batch for each bucket
+        elements = []
+        for batch_size, length in zip(batch_sizes, lengths):
+          record_len = length - 1
+          for _ in range(batch_size):
+            elements.append([1] * record_len)
+            record_len = length
+        random.shuffle(elements)
+        for el in elements:
+          yield (_format_record(el, sparse),)
+      dataset = dataset_ops.Dataset.from_generator(
+          _generator,
+          (_get_record_type(sparse),),
+          (_get_record_shape(sparse),))
+      if sparse:
+        dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+      return dataset
 
-    element_len = lambda el: array_ops.shape(el)[0]
-    dataset = dataset_ops.Dataset.from_generator(
-        element_gen, (dtypes.int64,), ([None],)).apply(
-            grouping.bucket_by_sequence_length(
-                element_len, boundaries, batch_sizes))
-    batch, = dataset.make_one_shot_iterator().get_next()
+    def _test_bucket_by_padding(no_padding):
+      dataset = build_dataset(sparse=no_padding)
+      dataset = dataset.apply(
+          grouping.bucket_by_sequence_length(
+              _element_length_fn,
+              boundaries,
+              batch_sizes,
+              no_padding=no_padding))
+      batch, = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
-      batches = []
-      for _ in range(4):
-        batches.append(sess.run(batch))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(batch)
-    batch_sizes_val = []
-    lengths_val = []
-    for batch in batches:
-      batch_size = batch.shape[0]
-      length = batch.shape[1]
-      batch_sizes_val.append(batch_size)
-      lengths_val.append(length)
-    self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
-    self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
-    self.assertEqual(sorted(lengths), sorted(lengths_val))
+      with self.cached_session() as sess:
+        batches = []
+        for _ in range(4):
+          batches.append(sess.run(batch))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(batch)
+      batch_sizes_val = []
+      lengths_val = []
+      for batch in batches:
+        shape = batch.dense_shape if no_padding else batch.shape
+        batch_size = shape[0]
+        length = shape[1]
+        batch_sizes_val.append(batch_size)
+        lengths_val.append(length)
+        sum_check = batch.values.sum() if no_padding else batch.sum()
+        self.assertEqual(sum_check, batch_size * length - 1)
+      self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+      self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+      self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+    for no_padding in (True, False):
+      _test_bucket_by_padding(no_padding)
 
   def testPadToBoundary(self):
 
@@ -600,7 +658,7 @@
                 pad_to_bucket_boundary=True))
     batch, = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batches = []
       for _ in range(3):
         batches.append(sess.run(batch))
@@ -637,7 +695,7 @@
                 pad_to_bucket_boundary=True))
     batch, = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batches = []
       for _ in range(5):
         batches.append(sess.run(batch))
@@ -657,28 +715,108 @@
 
   def testTupleElements(self):
 
-    def elements_gen():
-      text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
-      label = [1, 2, 1, 2]
-      for x, y in zip(text, label):
-        yield (x, y)
+    def build_dataset(sparse):
+      def _generator():
+        text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+        label = [1, 2, 1, 2]
+        for x, y in zip(text, label):
+          yield (_format_record(x, sparse), y)
+      dataset = dataset_ops.Dataset.from_generator(
+          generator=_generator,
+          output_types=(_get_record_type(sparse), dtypes.int32),
+          output_shapes=(_get_record_shape(sparse),
+                         tensor_shape.TensorShape([])))
+      if sparse:
+        dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+      return dataset
 
-    def element_length_fn(x, y):
-      del y
-      return array_ops.shape(x)[0]
+    def _test_tuple_elements_by_padding(no_padding):
+      dataset = build_dataset(sparse=no_padding)
+      dataset = dataset.apply(grouping.bucket_by_sequence_length(
+          element_length_func=_element_length_fn,
+          bucket_batch_sizes=[2, 2, 2],
+          bucket_boundaries=[0, 8],
+          no_padding=no_padding))
+      shapes = dataset.output_shapes
+      self.assertEqual([None, None], shapes[0].as_list())
+      self.assertEqual([None], shapes[1].as_list())
 
-    dataset = dataset_ops.Dataset.from_generator(
-        generator=elements_gen,
-        output_shapes=(tensor_shape.TensorShape([None]),
-                       tensor_shape.TensorShape([])),
-        output_types=(dtypes.int32, dtypes.int32))
+    for no_padding in (True, False):
+      _test_tuple_elements_by_padding(no_padding)
+
+  def testBucketSparse(self):
+    """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+    Test runs on following dataset:
+      [
+        [0],
+        [0, 1],
+        [0, 1, 2],
+        ...
+        [0, ..., max_len - 1]
+      ]
+    Sequences are bucketed by length and batched with
+      `batch_size` < `bucket_size`.
+    """
+
+    min_len = 0
+    max_len = 100
+    batch_size = 7
+    bucket_size = 10
+
+    def _build_dataset():
+      input_data = [range(i+1) for i in range(min_len, max_len)]
+      def generator_fn():
+        for record in input_data:
+          yield _format_record(record, sparse=True)
+      dataset = dataset_ops.Dataset.from_generator(
+          generator=generator_fn,
+          output_types=_get_record_type(sparse=True))
+      dataset = dataset.map(_to_sparse_tensor)
+      return dataset
+
+    def _compute_expected_batches():
+      """Computes expected batch outputs and stores in a set."""
+      all_expected_sparse_tensors = set()
+      for bucket_start_len in range(min_len, max_len, bucket_size):
+        for batch_offset in range(0, bucket_size, batch_size):
+          batch_start_len = bucket_start_len + batch_offset
+          batch_end_len = min(batch_start_len + batch_size,
+                              bucket_start_len + bucket_size)
+          expected_indices = []
+          expected_values = []
+          for length in range(batch_start_len, batch_end_len):
+            for val in range(length + 1):
+              expected_indices.append((length - batch_start_len, val))
+              expected_values.append(val)
+          expected_sprs_tensor = (tuple(expected_indices),
+                                  tuple(expected_values))
+          all_expected_sparse_tensors.add(expected_sprs_tensor)
+      return all_expected_sparse_tensors
+
+    def _compute_batches(dataset):
+      """Computes actual batch outputs of dataset and stores in a set."""
+      batch = dataset.make_one_shot_iterator().get_next()
+      all_sparse_tensors = set()
+      with self.cached_session() as sess:
+        with self.assertRaises(errors.OutOfRangeError):
+          while True:
+            output = sess.run(batch)
+            sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+                           tuple(output.values))
+            all_sparse_tensors.add(sprs_tensor)
+      return all_sparse_tensors
+
+    dataset = _build_dataset()
+    boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
     dataset = dataset.apply(grouping.bucket_by_sequence_length(
-        element_length_func=element_length_fn,
-        bucket_batch_sizes=[2, 2, 2],
-        bucket_boundaries=[0, 8]))
-    shapes = dataset.output_shapes
-    self.assertEqual([None, None], shapes[0].as_list())
-    self.assertEqual([None], shapes[1].as_list())
+        _element_length_fn,
+        boundaries,
+        [batch_size] * (len(boundaries) + 1),
+        no_padding=True))
+    batches = _compute_batches(dataset)
+    expected_batches = _compute_expected_batches()
+    self.assertEqual(batches, expected_batches)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 63bffd0..f8e74e4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -31,38 +31,49 @@
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_in_graph_and_eager_modes
 class CsvDatasetOpTest(test.TestCase):
 
-  def _assert_datasets_equal(self, g, ds1, ds2):
+  def _get_next(self, dataset):
+    # Returns a no-argument function whose result is fed to self.evaluate to
+    # yield the next element.
+    it = dataset.make_one_shot_iterator()
+    if context.executing_eagerly():
+      return it.get_next
+    else:
+      get_next = it.get_next()
+      return lambda: get_next
+
+  def _assert_datasets_equal(self, ds1, ds2):
     assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
                                                     '%s') % (ds1.output_shapes,
                                                              ds2.output_shapes)
     assert ds1.output_types == ds2.output_types
     assert ds1.output_classes == ds2.output_classes
-    next1 = ds1.make_one_shot_iterator().get_next()
-    next2 = ds2.make_one_shot_iterator().get_next()
-    with self.session(graph=g) as sess:
-      # Run through datasets and check that outputs match, or errors match.
-      while True:
-        try:
-          op1 = sess.run(next1)
-        except (errors.OutOfRangeError, ValueError) as e:
-          # If op1 throws an exception, check that op2 throws same exception.
-          with self.assertRaises(type(e)):
-            sess.run(next2)
-          break
-        op2 = sess.run(next2)
-        self.assertAllEqual(op1, op2)
+    next1 = self._get_next(ds1)
+    next2 = self._get_next(ds2)
+    # Run through datasets and check that outputs match, or errors match.
+    while True:
+      try:
+        op1 = self.evaluate(next1())
+      except (errors.OutOfRangeError, ValueError) as e:
+        # If op1 throws an exception, check that op2 throws same exception.
+        with self.assertRaises(type(e)):
+          self.evaluate(next2())
+        break
+      op2 = self.evaluate(next2())
+      self.assertAllEqual(op1, op2)
 
   def _setup_files(self, inputs, linebreak='\n', compression_type=None):
     filenames = []
@@ -95,33 +106,32 @@
 
   def _test_by_comparison(self, inputs, **kwargs):
     """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
-    with ops.Graph().as_default() as g:
-      dataset_actual, dataset_expected = self._make_test_datasets(
-          inputs, **kwargs)
-      self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+    dataset_actual, dataset_expected = self._make_test_datasets(
+        inputs, **kwargs)
+    self._assert_datasets_equal(dataset_actual, dataset_expected)
 
   def _verify_output_or_err(self,
-                            sess,
                             dataset,
                             expected_output=None,
                             expected_err_re=None):
-    nxt = dataset.make_one_shot_iterator().get_next()
     if expected_err_re is None:
       # Verify that output is expected, without errors
+      nxt = self._get_next(dataset)
       expected_output = [[
           v.encode('utf-8') if isinstance(v, str) else v for v in op
       ] for op in expected_output]
       for value in expected_output:
-        op = sess.run(nxt)
+        op = self.evaluate(nxt())
         self.assertAllEqual(op, value)
       with self.assertRaises(errors.OutOfRangeError):
-        sess.run(nxt)
+        self.evaluate(nxt())
     else:
       # Verify that OpError is produced as expected
       with self.assertRaisesOpError(expected_err_re):
+        nxt = self._get_next(dataset)
         while True:
           try:
-            sess.run(nxt)
+            self.evaluate(nxt())
           except errors.OutOfRangeError:
             break
 
@@ -137,11 +147,8 @@
     # Convert str type because py3 tf strings are bytestrings
     filenames = self._setup_files(inputs, linebreak, compression_type)
     kwargs['compression_type'] = compression_type
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, **kwargs)
-        self._verify_output_or_err(sess, dataset, expected_output,
-                                   expected_err_re)
+    dataset = readers.CsvDataset(filenames, **kwargs)
+    self._verify_output_or_err(dataset, expected_output, expected_err_re)
 
   def testCsvDataset_requiredFields(self):
     record_defaults = [[]] * 4
@@ -191,21 +198,17 @@
     record_defaults = [['']] * 3
     inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
     filenames = self._setup_files(inputs)
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
-        dataset = dataset.apply(error_ops.ignore_errors())
-        self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+    dataset = dataset.apply(error_ops.ignore_errors())
+    self._verify_output_or_err(dataset, [['e', 'f', 'g']])
 
   def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
     record_defaults = [['']] * 3
     inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
     filenames = self._setup_files(inputs)
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
-        dataset = dataset.apply(error_ops.ignore_errors())
-        self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+    dataset = dataset.apply(error_ops.ignore_errors())
+    self._verify_output_or_err(dataset, [['e', 'f', 'g']])
 
   def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
     record_defaults = [['']] * 3
@@ -351,10 +354,9 @@
     inputs = [['1,,3,4', '5,6,,8']]
     ds_actual, ds_expected = self._make_test_datasets(
         inputs, record_defaults=record_defaults)
-    with ops.Graph().as_default() as g:
-      self._assert_datasets_equal(g,
-                                  ds_actual.repeat(5).prefetch(1),
-                                  ds_expected.repeat(5).prefetch(1))
+    self._assert_datasets_equal(
+        ds_actual.repeat(5).prefetch(1),
+        ds_expected.repeat(5).prefetch(1))
 
   def testCsvDataset_withTypeDefaults(self):
     # Testing using dtypes as record_defaults for required fields
@@ -373,13 +375,11 @@
     ]]
     file_path = self._setup_files(data)
 
-    with ops.Graph().as_default() as g:
-      ds = readers.make_csv_dataset(
-          file_path, batch_size=1, shuffle=False, num_epochs=1)
-      next_batch = ds.make_one_shot_iterator().get_next()
+    ds = readers.make_csv_dataset(
+        file_path, batch_size=1, shuffle=False, num_epochs=1)
+    nxt = self._get_next(ds)
 
-    with self.session(graph=g) as sess:
-      result = list(sess.run(next_batch).values())
+    result = list(self.evaluate(nxt()).values())
 
     self.assertEqual(result, sorted(result))
 
@@ -542,6 +542,29 @@
         compression_type='ZLIB',
         record_defaults=record_defaults)
 
+  def testCsvDataset_withScalarDefaults(self):
+    record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+    inputs = [[',,,', '1,1,1,', ',2,2,2']]
+    self._test_dataset(
+        inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+        record_defaults=record_defaults)
+
+  def testCsvDataset_with2DDefaults(self):
+    record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+    inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+    if context.executing_eagerly():
+      err_spec = errors.InvalidArgumentError, (
+          'Each record default should be at '
+          'most rank 1.')
+    else:
+      err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+    with self.assertRaisesWithPredicateMatch(*err_spec):
+      self._test_dataset(
+          inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+          record_defaults=record_defaults)
+
 
 class CsvDatasetBenchmark(test.Benchmark):
   """Benchmarks for the various ways of creating a dataset from CSV files.
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9020a49..eb11032 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -38,7 +38,7 @@
     iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for _ in range(100):
         for i in range(10):
@@ -67,7 +67,7 @@
     iterator = dataset.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       freqs = np.zeros([num_datasets])
       for _ in range(num_samples):
         freqs[sess.run(next_element)] += 1
@@ -104,7 +104,7 @@
     iterator = dataset.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in choice_array:
         self.assertEqual(words[i], sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index e6883d5..f3968cd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -53,7 +53,7 @@
         lambda x: (x * x, make_sparse(x))).take(take_t)
     element = get_single_element.get_single_element(dataset)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if error is None:
         dense_val, sparse_val = sess.run(
             element, feed_dict={
@@ -90,7 +90,7 @@
     dataset = dataset_ops.Dataset.range(stop_t)
     element = get_single_element.reduce_dataset(dataset, sum_reducer)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       value = sess.run(element, feed_dict={stop_t: stop})
       self.assertEqual(stop * (stop - 1) / 2, value)
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index db2ab81..9c508d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -44,14 +44,14 @@
     get_op = gen_dataset_ops.indexed_dataset_get(
         handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(materialize)
       self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
 
   def testIdentityIndexedDataset(self):
     ds = indexed_dataset_ops.IdentityIndexedDataset(16)
     materialized = ds.materialize()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(materialized.initializer)
       placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
       for i in range(16):
@@ -66,7 +66,7 @@
     ds = indexed_dataset_ops.IdentityIndexedDataset(16)
     itr = ds.make_initializable_iterator()
     n = itr.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(itr.initializer)
       for i in range(16):
         output = sess.run(n)
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 7a3215f..b9e74df 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -177,7 +177,7 @@
   def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
     # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
     # `Dataset.flat_map()` and is single-threaded. No synchronization required.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -212,7 +212,7 @@
 
   def testSingleThreadedRagged(self):
     # Tests a sequence with wildly different elements per iterator.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -242,7 +242,7 @@
   def _testTwoThreadsNoContention(self, sloppy=False):
     # num_threads > 1.
     # Explicit coordination should result in `Dataset.interleave()` behavior
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -286,7 +286,7 @@
     Args:
       sloppy: Whether to be sloppy or not.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -328,7 +328,7 @@
   def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
     # num_threads > 1.
     # Explicit coordination should result in `Dataset.interleave()` behavior
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -373,7 +373,7 @@
     Args:
       sloppy: Whether to be sloppy or not.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -413,7 +413,7 @@
     self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
 
   def _testEmptyInput(self, sloppy=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Empty input.
       self._clear_coordination_events()
       sess.run(
@@ -437,7 +437,7 @@
 
   def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
     # Non-empty input leading to empty output.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -461,7 +461,7 @@
   def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
     race_indices = {2, 8, 14}  # Sequence points when sloppy mode has race conds
     # Mixture of non-empty and empty interleaved datasets.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -500,7 +500,7 @@
   def testDelayedOutputSloppy(self):
     # Explicitly control the sequence of events to ensure we correctly avoid
     # head-of-line blocking.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -525,7 +525,7 @@
         sess.run(self.next_element)
 
   def testBlockLengthWithContentionSloppy(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       done_first_event = False
       sess.run(
@@ -560,7 +560,7 @@
 
   def _testEarlyExit(self, sloppy=False):
     # Exiting without consuming all input should not block
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -604,7 +604,7 @@
             interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output_values = []
       for _ in range(30):
         output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         for j in range(2):
@@ -645,7 +645,7 @@
         sess.run(get_next)
 
   def testErrorsInOutputFn(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._clear_coordination_events()
       sess.run(
           self.init_op,
@@ -704,7 +704,7 @@
     self.init_op = self.iterator.initializer
     self.next_element = self.iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_op,
           feed_dict={
@@ -753,7 +753,7 @@
     self.init_op = self.iterator.initializer
     self.next_element = self.iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_op,
           feed_dict={
@@ -792,7 +792,7 @@
     next_element = iterator.get_next()
 
     results = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(2):
         elements = []
         sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 7bc582e..1cc5ddc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -51,7 +51,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(num_repeats):  # Dataset is repeated.
         for i in range(10):  # 10 records.
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index dc9d56d..e851938 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -54,7 +54,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for x in [1., 2., 3., 5.]:
         self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for x in [1., 2., 3., 5.]:
         self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # All of the files are present.
       sess.run(init_op)
       for filename in filenames:
@@ -209,7 +209,7 @@
             end = time.time()
             chained_deltas.append(end - start)
 
-        fused_dataset = dataset = dataset.apply(
+        fused_dataset = dataset.apply(
             batching.map_and_batch(
                 math_ops.matmul,
                 num_parallel_calls=num_calls,
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 73cde40..83b7237 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import time
+
 from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -25,10 +28,11 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
-
 class MapDefunTest(test.TestCase):
 
   def testMapDefunSimple(self):
@@ -130,6 +134,146 @@
     with self.assertRaises(errors.InvalidArgumentError):
       self.evaluate(result)
 
+  def testMapDefunCancelledCorrectly(self):
+
+    @function.Defun(dtypes.int64)
+    def defun(x):
+      # x has leading dimension 5, this will raise an error
+      return array_ops.gather(x, 10)
+
+    c = array_ops.tile(
+        array_ops.expand_dims(
+            constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+        [100, 1])
+    map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+    with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                 r"indices = 10 is not in \[0, 5\)"):
+      self.evaluate(map_defun_op)
+
+  def testMapDefunWithUnspecifiedOutputShape(self):
+
+    @function.Defun(dtypes.int32)
+    def simple_fn(x):
+      res = x * 2 + 3
+      return (res, res + 1, res + 2)
+
+    nums = [[1, 2], [3, 4], [5, 6]]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(simple_fn, [elems],
+                            [dtypes.int32, dtypes.int32, dtypes.int32],
+                            [None, (None,), (2,)])
+    expected = elems * 2 + 3
+    self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+    self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+    self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+  def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+    @function.Defun(dtypes.int32)
+    def simple_fn(x):
+      return x * 2 + 3
+
+    elems = array_ops.placeholder(dtypes.int32, name="data")
+    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+    with session.Session() as sess:
+      self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+      self.assertAllEqual(
+          sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+  def testMapDefunWithWrongOutputShape(self):
+
+    @function.Defun(dtypes.int32)
+    def simple_fn(x):
+      return x * 2 + 3
+
+    nums = [[1, 2], [3, 4], [5, 6]]
+    elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+    r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(r)
+
+  def testMapDefunWithInvalidInput(self):
+
+    @function.Defun(dtypes.int32)
+    def simple_fn(x):
+      return x * 2
+
+    c = constant_op.constant(2)
+    with self.assertRaises(ValueError):
+      # Fails at graph construction time for inputs with known shapes.
+      r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+    p = array_ops.placeholder(dtypes.int32)
+    r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+    with session.Session() as sess:
+      with self.assertRaises(errors.InvalidArgumentError):
+        sess.run(r, feed_dict={p: 0})
+
+  def _assert_op_cancelled(self, sess, map_defun_op):
+    with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
+      sess.run(map_defun_op)
+
+  def testMapDefunWithParentCancellation(self):
+    # Checks that a cancellation of the parent graph is threaded through to
+    # MapDefunOp correctly.
+    @function.Defun(dtypes.int32)
+    def simple_fn(x):
+      del x
+      queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
+      # Blocking
+      return queue.dequeue_many(5)
+
+    c = constant_op.constant([1, 2, 3, 4, 5])
+    map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
+
+    with self.test_session() as sess:
+      thread = self.checkedThread(
+          self._assert_op_cancelled, args=(sess, map_defun_op))
+      thread.start()
+      time.sleep(0.1)
+      sess.close()
+      thread.join()
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+  def _run(self, op, name=None, num_iters=3000):
+    with session.Session() as sess:
+      # Warm up the session
+      for _ in range(5):
+        sess.run(op)
+      start = time.time()
+      for _ in range(num_iters):
+        sess.run(op)
+      end = time.time()
+      mean_us = (end - start) * 1e6 / num_iters
+      self.report_benchmark(
+          name=name,
+          iters=num_iters,
+          wall_time=mean_us,
+          extras={"examples_per_sec": num_iters / (end - start)})
+
+  def benchmarkDefunVsMapFn(self):
+    """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+    @function.Defun(dtypes.int32)
+    def defun(x):
+      return array_ops.identity(x)
+
+    def map_fn(x):
+      return array_ops.identity(x)
+
+    base = math_ops.range(100)
+    for input_size in [10, 100, 1000, 10000]:
+      num_iters = 100000 // input_size
+      map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+      map_fn_op = functional_ops.map_fn(map_fn, base)
+
+      self._run(
+          map_defun_op,
+          "benchmarkMapDefun_size_%d" % input_size,
+          num_iters=num_iters)
+      self._run(
+          map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index b299e07..7e9ea68 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -7,6 +7,34 @@
 load("//tensorflow:tensorflow.bzl", "py_test")
 
 py_test(
+    name = "assert_next_dataset_op_test",
+    size = "medium",
+    srcs = ["assert_next_dataset_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/ops:optimization",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
+
+py_test(
+    name = "latency_all_edges_test",
+    size = "small",
+    srcs = ["latency_all_edges_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+        "//tensorflow/contrib/data/python/ops:optimization",
+        "//tensorflow/contrib/data/python/ops:stats_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/ops:dataset_ops",
+    ],
+)
+
+py_test(
     name = "map_vectorization_test",
     size = "small",
     srcs = ["map_vectorization_test.py"],
@@ -46,16 +74,34 @@
 )
 
 py_test(
-    name = "latency_all_edges_test",
-    size = "small",
-    srcs = ["latency_all_edges_test.py"],
+    name = "model_dataset_op_test",
+    size = "medium",
+    srcs = ["model_dataset_op_test.py"],
     srcs_version = "PY2AND3",
+    tags = [
+        "optonly",
+    ],
     deps = [
-        "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+        "//tensorflow/contrib/data/python/ops:batching",
+        "//tensorflow/contrib/data/python/ops:interleave_ops",
         "//tensorflow/contrib/data/python/ops:optimization",
-        "//tensorflow/contrib/data/python/ops:stats_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:errors",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "optimize_dataset_op_test",
+    size = "small",
+    srcs = ["optimize_dataset_op_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/contrib/data/python/ops:optimization",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000..bd7b50b
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test.TestCase):
+
+  def testAssertNext(self):
+    dataset = dataset_ops.Dataset.from_tensors(0).apply(
+        optimization.assert_next(["Map"])).map(lambda x: x)
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      self.assertEqual(0, sess.run(get_next))
+
+  def testAssertNextInvalid(self):
+    dataset = dataset_ops.Dataset.from_tensors(0).apply(
+        optimization.assert_next(["Whoops"])).map(lambda x: x)
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Asserted Whoops transformation at offset 0 but encountered "
+          "Map transformation instead."):
+        sess.run(get_next)
+
+  def testAssertNextShort(self):
+    dataset = dataset_ops.Dataset.from_tensors(0).apply(
+        optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.test_session() as sess:
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Asserted next 2 transformations but encountered only 1."):
+        sess.run(get_next)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
index 1850b69..db380c0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -40,7 +40,7 @@
     get_next = iterator.get_next()
     summary_t = stats_aggregator.get_summary()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       self.assertEqual(1 * 1, sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 586b4be..dde1159 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -44,22 +44,22 @@
     for i, fun1 in enumerate(functions):
       for j, fun2 in enumerate(functions):
         tests.append((
-            "test_{}_{}".format(i, j),
+            "Test{}{}".format(i, j),
             [fun1, fun2],
         ))
         for k, fun3 in enumerate(functions):
           tests.append((
-              "test_{}_{}_{}".format(i, j, k),
+              "Test{}{}{}".format(i, j, k),
               [fun1, fun2, fun3],
           ))
 
     swap = lambda x, n: (n, x)
     tests.append((
-        "swap1",
+        "Swap1",
         [lambda x: (x, 42), swap],
     ))
     tests.append((
-        "swap2",
+        "Swap2",
         [lambda x: (x, 42), swap, swap],
     ))
     return tuple(tests)
@@ -74,7 +74,7 @@
     dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for x in range(5):
         result = sess.run(get_next)
         r = x
@@ -109,13 +109,13 @@
 
     for x, fun in enumerate(functions):
       for y, predicate in enumerate(filters):
-        tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
+        tests.append(("Mixed{}{}".format(x, y), fun, predicate))
 
     # Multi output
-    tests.append(("multiOne", lambda x: (x, x),
+    tests.append(("Multi1", lambda x: (x, x),
                   lambda x, y: constant_op.constant(True)))
     tests.append(
-        ("multiTwo", lambda x: (x, 2),
+        ("Multi2", lambda x: (x, 2),
          lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
     return tuple(tests)
 
@@ -131,7 +131,7 @@
   def _testMapAndFilter(self, dataset, function, predicate):
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for x in range(10):
         r = function(x)
         if isinstance(r, tuple):
@@ -172,17 +172,17 @@
     identity = lambda x: x
     for x, predicate_1 in enumerate(filters):
       for y, predicate_2 in enumerate(filters):
-        tests.append(("mixed_{}_{}".format(x, y), identity,
+        tests.append(("Mixed{}{}".format(x, y), identity,
                       [predicate_1, predicate_2]))
         for z, predicate_3 in enumerate(filters):
-          tests.append(("mixed_{}_{}_{}".format(x, y, z), identity,
+          tests.append(("Mixed{}{}{}".format(x, y, z), identity,
                         [predicate_1, predicate_2, predicate_3]))
 
     take_all_multiple = lambda x, y: constant_op.constant(True)
     # Multi output
-    tests.append(("multiOne", lambda x: (x, x),
+    tests.append(("Multi1", lambda x: (x, x),
                   [take_all_multiple, take_all_multiple]))
-    tests.append(("multiTwo", lambda x: (x, 2), [
+    tests.append(("Multi2", lambda x: (x, 2), [
         take_all_multiple,
         lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
     ]))
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
new file mode 100644
index 0000000..0a87d3e
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -0,0 +1,177 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ModelDatasetTest(test.TestCase):
+
+  def testModelMap(self):
+    k = 1024 * 1024
+    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+                                                np.random.rand(4 * k,
+                                                               1))).repeat()
+    dataset = dataset.map(math_ops.matmul)
+    iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    deltas = []
+    with self.test_session() as sess:
+      for _ in range(5):
+        sess.run(get_next.op)
+      for _ in range(100):
+        start = time.time()
+        sess.run(get_next.op)
+        end = time.time()
+        deltas.append(end - start)
+
+    print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+          (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+           np.max(deltas)))
+
+  def testModelParallelMap(self):
+    k = 1024 * 1024
+    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+                                                np.random.rand(4 * k,
+                                                               1))).repeat()
+    dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
+    iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    deltas = []
+    with self.test_session() as sess:
+      for _ in range(5):
+        sess.run(get_next.op)
+      for _ in range(1000):
+        start = time.time()
+        sess.run(get_next.op)
+        end = time.time()
+        deltas.append(end - start)
+
+    print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+          (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+           np.max(deltas)))
+
+  def testModelMapAndBatch(self):
+    batch_size = 16
+    k = 1024 * 1024
+    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+                                                np.random.rand(4 * k,
+                                                               1))).repeat()
+    dataset = dataset.apply(
+        batching.map_and_batch(
+            math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
+    iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    deltas = []
+    with self.test_session() as sess:
+      for _ in range(5):
+        sess.run(get_next.op)
+      for _ in range(10):
+        start = time.time()
+        sess.run(get_next.op)
+        end = time.time()
+        deltas.append(end - start)
+
+    print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+          (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+           np.max(deltas)))
+
+  def testModelParallelInterleave(self):
+    k = 1024 * 1024
+    dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+                                                np.random.rand(4 * k,
+                                                               1))).repeat()
+    dataset = dataset.map(math_ops.matmul)
+    dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+        lambda _: dataset, cycle_length=56, num_parallel_calls=56)
+    iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    deltas = []
+    with self.test_session() as sess:
+      for _ in range(5):
+        sess.run(get_next.op)
+      for _ in range(1000):
+        start = time.time()
+        sess.run(get_next.op)
+        end = time.time()
+        deltas.append(end - start)
+
+    print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+          (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+           np.max(deltas)))
+
+  def testModelNested(self):
+    k = 1024 * 1024
+    a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+    b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+    c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+    dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+    def f1(a, b, c):
+      x, y = a
+      return math_ops.matmul(x, y), b, c
+
+    def f2(a, b, c):
+      x, y = b
+      return a, math_ops.matmul(x, y), c
+
+    def f3(a, b, c):
+      x, y = c
+      return a, b, math_ops.matmul(x, y)
+
+    dataset = dataset.map(f1, num_parallel_calls=32)
+    dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+        lambda _: dataset, cycle_length=2)
+
+    dataset = dataset.map(f2, num_parallel_calls=16)
+    dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+        lambda _: dataset, cycle_length=2)
+
+    dataset = dataset.map(f3, num_parallel_calls=10)
+    iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    deltas = []
+    with self.test_session() as sess:
+      for _ in range(5):
+        sess.run(get_next)
+      for _ in range(100):
+        start = time.time()
+        sess.run(get_next)
+        end = time.time()
+        deltas.append(end - start)
+
+    print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+          (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+           np.max(deltas)))
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
similarity index 75%
rename from tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
rename to tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 0897171..909da5a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.contrib.data.python.ops import optimization
@@ -29,41 +28,7 @@
 from tensorflow.python.platform import test
 
 
-class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
-
-  def testAssertSuffix(self):
-    dataset = dataset_ops.Dataset.from_tensors(0).apply(
-        optimization.assert_next(["Map"])).map(lambda x: x)
-    iterator = dataset.make_one_shot_iterator()
-    get_next = iterator.get_next()
-
-    with self.test_session() as sess:
-      self.assertEqual(0, sess.run(get_next))
-
-  def testAssertSuffixInvalid(self):
-    dataset = dataset_ops.Dataset.from_tensors(0).apply(
-        optimization.assert_next(["Whoops"])).map(lambda x: x)
-    iterator = dataset.make_one_shot_iterator()
-    get_next = iterator.get_next()
-
-    with self.test_session() as sess:
-      with self.assertRaisesRegexp(
-          errors.InvalidArgumentError,
-          "Asserted Whoops transformation at offset 0 but encountered "
-          "Map transformation instead."):
-        sess.run(get_next)
-
-  def testAssertSuffixShort(self):
-    dataset = dataset_ops.Dataset.from_tensors(0).apply(
-        optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
-    iterator = dataset.make_one_shot_iterator()
-    get_next = iterator.get_next()
-
-    with self.test_session() as sess:
-      with self.assertRaisesRegexp(
-          errors.InvalidArgumentError,
-          "Asserted next 2 transformations but encountered only 1."):
-        sess.run(get_next)
+class OptimizeDatasetTest(test.TestCase):
 
   def testOptimizationDefault(self):
     dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index f6c4a98..c4623bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -80,7 +80,7 @@
             expected_values=None,
             expected_err=None):
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 361fe0d..0166ba0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,7 +235,7 @@
       destroy_op = resource_variable_ops.destroy_resource_op(
           buffer_resource_handle, ignore_lookup_error=True)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual([b"a"], sess.run(prefetch_op))
       self.assertEqual([b"b"], sess.run(prefetch_op))
       self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@
     self.assertEqual(dtypes.int64, next_element.dtype)
     self.assertEqual([], next_element.shape)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual(i, sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@
     iterator = device_dataset.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(10):
         self.assertEqual(i, sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@
     iterator = device_dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(5):
         self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(10):
         self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(10):
         self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
       with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@
         iterator = back_to_cpu_dataset.make_initializable_iterator()
         next_element = iterator.get_next()
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(iterator.initializer)
         for i in range(10):
           self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(5):
         self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@
       iterator = device_dataset.make_initializable_iterator()
       next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       for i in range(5):
         self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@
       elem_has_value_t = next_elem.has_value()
       elem_value_t = next_elem.get_value()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Before initializing the iterator, evaluating the optional fails with
       # a FailedPreconditionError.
       with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 592642d..db8fe6a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -43,7 +43,7 @@
     self.assertEqual([tensor_shape.TensorShape([])] * 3,
                      [t.shape for t in get_next[1]])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
       self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@
                          .make_one_shot_iterator())
     negative_get_next = negative_iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(3, sess.run(get_next))
       self.assertEqual(3 + 4, sess.run(get_next))
       self.assertEqual(3 + 2 * 4, sess.run(get_next))
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index fd00cdc..ed75b27 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -116,7 +116,7 @@
     init_op = iterator.initializer
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
           range(self._num_files), 2, 10):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c5cfddb..16b1441 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -77,7 +77,7 @@
             class_func=lambda c, _: c,
             seed=27)).make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       returned = []
       while len(returned) < 4000:
         returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@
 
     get_next = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       returned = []
       with self.assertRaises(errors.OutOfRangeError):
         while True:
@@ -146,7 +146,7 @@
 
     get_next = dataset.make_one_shot_iterator().get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       returned = []
       with self.assertRaises(errors.OutOfRangeError):
         while True:
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 42cada0..dde678b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -50,7 +50,7 @@
         start, make_scan_fn(step)).take(take).make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
                                             (10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@
         make_scan_fn(step)).take(take).make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
                                             (10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@
     iterator = dataset.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(5):
         (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
         self.assertAllEqual([0] * (2**i), longer_vector_val)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 4881f63..aa89674 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python/data/ops:dataset_ops",
         "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892f..243f640 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@
 
 
 class InterleaveDatasetSerializationTest(
-    dataset_serialization_test_base.DatasetSerializationTestBase):
+    dataset_serialization_test_base.DatasetSerializationTestBase,
+    parameterized.TestCase):
 
-  def _build_iterator_graph(self, input_values, cycle_length, block_length):
+  def _build_iterator_graph(self, input_values, cycle_length, block_length,
+                            num_parallel_calls):
     repeat_count = 2
     return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
         repeat_count).interleave(
             lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
-            cycle_length, block_length)
+            cycle_length, block_length, num_parallel_calls)
 
-  def testSerializationCore(self):
+  @parameterized.named_parameters(
+      ("1", 2, 3, None),
+      ("2", 2, 3, 1),
+      ("3", 2, 3, 2),
+      ("4", 1, 3, None),
+      ("5", 1, 3, 1),
+      ("6", 2, 1, None),
+      ("7", 2, 1, 1),
+      ("8", 2, 1, 2),
+  )
+  def testSerializationCore(self, cycle_length, block_length,
+                            num_parallel_calls):
     input_values = np.array([4, 5, 6], dtype=np.int64)
     num_outputs = np.sum(input_values) * 2
-    # cycle_length > 1, block_length > 1
-    cycle_length = 2
-    block_length = 3
     # pylint: disable=g-long-lambda
     self.run_core_tests(
         lambda: self._build_iterator_graph(
-            input_values, cycle_length, block_length),
+            input_values, cycle_length, block_length, num_parallel_calls),
         lambda: self._build_iterator_graph(
-            input_values, cycle_length * 2, block_length * 1),
+            input_values, cycle_length * 2, block_length, num_parallel_calls),
         num_outputs)
-    # cycle_length = 1
-    cycle_length = 1
-    block_length = 3
-    self.run_core_tests(
-        lambda: self._build_iterator_graph(
-            input_values, cycle_length, block_length),
-        None, num_outputs)
-    # block_length = 1
-    cycle_length = 2
-    block_length = 1
-    self.run_core_tests(
-        lambda: self._build_iterator_graph(
-            input_values, cycle_length, block_length),
-        None, num_outputs)
     # pylint: enable=g-long-lambda
 
   def testSparseCore(self):
@@ -82,5 +79,5 @@
     self.run_core_tests(_build_dataset, None, 20)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 077abd6..440e48d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -35,7 +35,7 @@
   def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
     get_next = ds_fn().make_one_shot_iterator().get_next()
     outputs = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(num_outputs):
         outputs.append(sess.run(get_next))
       if verify_exhausted:
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846..90d18dc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@
 
 class SlideDatasetTest(test.TestCase, parameterized.TestCase):
 
-  @parameterized.parameters(
-      (20, 14, 7, 1),
-      (20, 17, 9, 1),
-      (20, 14, 14, 1),
-      (20, 10, 14, 1),
-      (20, 14, 19, 1),
-      (20, 4, 1, 2),
-      (20, 2, 1, 6),
-      (20, 4, 7, 2),
-      (20, 2, 7, 6),
-      (1, 10, 4, 1),
-      (0, 10, 4, 1),
+  @parameterized.named_parameters(
+      ("1", 20, 14, 7, 1),
+      ("2", 20, 17, 9, 1),
+      ("3", 20, 14, 14, 1),
+      ("4", 20, 10, 14, 1),
+      ("5", 20, 14, 19, 1),
+      ("6", 20, 4, 1, 2),
+      ("7", 20, 2, 1, 6),
+      ("8", 20, 4, 7, 2),
+      ("9", 20, 2, 7, 6),
+      ("10", 1, 10, 4, 1),
+      ("11", 0, 10, 4, 1),
   )
   def testSlideDataset(self, count, window_size, window_shift, window_stride):
     """Tests a dataset that slides a window its input elements."""
@@ -75,7 +75,7 @@
     self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                      [t.shape.as_list() for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -96,18 +96,18 @@
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @parameterized.parameters(
-      (20, 14, 7, 1),
-      (20, 17, 9, 1),
-      (20, 14, 14, 1),
-      (20, 10, 14, 1),
-      (20, 14, 19, 1),
-      (20, 4, 1, 2),
-      (20, 2, 1, 6),
-      (20, 4, 7, 2),
-      (20, 2, 7, 6),
-      (1, 10, 4, 1),
-      (0, 10, 4, 1),
+  @parameterized.named_parameters(
+      ("1", 20, 14, 7, 1),
+      ("2", 20, 17, 9, 1),
+      ("3", 20, 14, 14, 1),
+      ("4", 20, 10, 14, 1),
+      ("5", 20, 14, 19, 1),
+      ("6", 20, 4, 1, 2),
+      ("7", 20, 2, 1, 6),
+      ("8", 20, 4, 7, 2),
+      ("9", 20, 2, 7, 6),
+      ("10", 1, 10, 4, 1),
+      ("11", 0, 10, 4, 1),
   )
   def testSlideDatasetDeprecated(self, count, window_size, stride,
                                  window_stride):
@@ -139,7 +139,7 @@
     self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                      [t.shape.as_list() for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -160,10 +160,10 @@
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  @parameterized.parameters(
-      (14, 0, 3, 1),
-      (14, 3, 0, 1),
-      (14, 3, 3, 0),
+  @parameterized.named_parameters(
+      ("1", 14, 0, 3, 1),
+      ("2", 14, 3, 0, 1),
+      ("3", 14, 3, 3, 0),
   )
   def testSlideDatasetInvalid(self, count, window_size, window_shift,
                               window_stride):
@@ -180,7 +180,7 @@
                 window_stride=window_stride_t)).make_initializable_iterator())
     init_op = iterator.initializer
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(
             init_op,
@@ -214,7 +214,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       num_batches = (10 - 5) // 3 + 1
       for i in range(num_batches):
@@ -243,7 +243,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       num_batches = (10 - 5) // 3 + 1
       for i in range(num_batches):
@@ -277,7 +277,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       # Slide: 1st batch.
       actual = sess.run(get_next)
@@ -316,7 +316,7 @@
         .make_initializable_iterator())
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 2c2cfbe..52823d3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,7 +30,7 @@
   def testReadResultSet(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string), 2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(2):  # Run twice to verify statelessness of db operations.
         sess.run(
             init_op,
@@ -48,7 +48,7 @@
   def testReadResultSetJoinQuery(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -67,7 +67,7 @@
   def testReadResultSetNullTerminator(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -86,7 +86,7 @@
   def testReadResultSetReuseSqlDataset(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -114,7 +114,7 @@
   def testReadEmptyResultSet(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -128,7 +128,7 @@
   def testReadResultSetWithInvalidDriverName(self):
     init_op = self._createSqlDataset((dtypes.string, dtypes.string,
                                       dtypes.string))[0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(
             init_op,
@@ -142,7 +142,7 @@
   def testReadResultSetWithInvalidColumnName(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -157,7 +157,7 @@
   def testReadResultSetOfQueryWithSyntaxError(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -173,7 +173,7 @@
   def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -190,7 +190,7 @@
   def testReadResultSetOfInsertQuery(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.string))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -205,7 +205,7 @@
   # place it in an `int8` tensor.
   def testReadResultSetInt8(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -222,7 +222,7 @@
   def testReadResultSetInt8NegativeAndZero(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
                                                 dtypes.int8))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -238,7 +238,7 @@
   # a SQLite database table and place it in an `int8` tensor.
   def testReadResultSetInt8MaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -256,7 +256,7 @@
   # place it in an `int16` tensor.
   def testReadResultSetInt16(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -273,7 +273,7 @@
   def testReadResultSetInt16NegativeAndZero(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
                                                 dtypes.int16))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -289,7 +289,7 @@
   # a SQLite database table and place it in an `int16` tensor.
   def testReadResultSetInt16MaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -307,7 +307,7 @@
   # place it in an `int32` tensor.
   def testReadResultSetInt32(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -321,7 +321,7 @@
   # SQLite database table and place it in an `int32` tensor.
   def testReadResultSetInt32NegativeAndZero(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -337,7 +337,7 @@
   # a SQLite database table and place it in an `int32` tensor.
   def testReadResultSetInt32MaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -355,7 +355,7 @@
   # table and place it in an `int32` tensor.
   def testReadResultSetInt32VarCharColumnAsInt(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -371,7 +371,7 @@
   # and place it in an `int64` tensor.
   def testReadResultSetInt64(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -387,7 +387,7 @@
   # SQLite database table and place it in an `int64` tensor.
   def testReadResultSetInt64NegativeAndZero(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -403,7 +403,7 @@
   # a SQLite database table and place it in an `int64` tensor.
   def testReadResultSetInt64MaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -422,7 +422,7 @@
   # place it in a `uint8` tensor.
   def testReadResultSetUInt8(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -438,7 +438,7 @@
   # SQLite database table and place them in `uint8` tensors.
   def testReadResultSetUInt8MinAndMaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -456,7 +456,7 @@
   # and place it in a `uint16` tensor.
   def testReadResultSetUInt16(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -472,7 +472,7 @@
   # SQLite database table and place them in `uint16` tensors.
   def testReadResultSetUInt16MinAndMaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -491,7 +491,7 @@
   # in `bool` tensors.
   def testReadResultSetBool(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -508,7 +508,7 @@
   # from a SQLite database table and place it as `True` in a `bool` tensor.
   def testReadResultSetBoolNotZeroOrOne(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -525,7 +525,7 @@
   def testReadResultSetFloat64(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.float64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -544,7 +544,7 @@
   def testReadResultSetFloat64OverlyPrecise(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.float64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -570,7 +570,7 @@
   def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
                                                 dtypes.float64))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 43067b4..e25570c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -75,6 +75,31 @@
         sess.run(next_element)
       self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
 
+  def testPrefetchBufferUtilization(self):
+    stats_aggregator = stats_ops.StatsAggregator()
+    dataset = dataset_ops.Dataset.range(100).map(
+        lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+            -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+    iterator = dataset.make_initializable_iterator()
+    next_element = iterator.get_next()
+    summary_t = stats_aggregator.get_summary()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      for i in range(100):
+        self.assertAllEqual(
+            np.array([i] * i, dtype=np.int64), sess.run(next_element))
+        summary_str = sess.run(summary_t)
+        self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+                                    float(i + 1))
+        self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+                                    0, 1)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+      summary_str = sess.run(summary_t)
+      self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+                                  100)
+
   def testReinitialize(self):
     stats_aggregator = stats_ops.StatsAggregator()
     dataset = dataset_ops.Dataset.range(100).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 9a13acf..2f5a444 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -34,6 +34,16 @@
         return
     self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
 
+  def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+    summary_proto = summary_pb2.Summary()
+    summary_proto.ParseFromString(summary_str)
+    for value in summary_proto.value:
+      if tag == value.tag:
+        self.assertLessEqual(min_value, value.histo.min)
+        self.assertGreaterEqual(max_value, value.histo.max)
+        return
+    self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
   def _assertSummaryHasSum(self, summary_str, tag, expected_value):
     summary_proto = summary_pb2.Summary()
     summary_proto.ParseFromString(summary_str)
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1d70b16..4c3353f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -31,7 +31,7 @@
     # TODO(rachelim): support sparse tensor outputs
     next1 = dataset1.make_one_shot_iterator().get_next()
     next2 = dataset2.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       while True:
         try:
           op1 = sess.run(next1)
@@ -52,9 +52,12 @@
                                         dataset2,
                                         exception_class,
                                         replacements=None):
-    next1 = dataset1.make_one_shot_iterator().get_next()
-    next2 = dataset2.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    # We are defining next1 and next2 on the same line so that we get identical
+    # file:line_number in the error messages.
+    # pylint: disable=line-too-long
+    next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
+    # pylint: enable=line-too-long
+    with self.cached_session() as sess:
       try:
         sess.run(next1)
         raise ValueError(
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2b..8d335e8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@
 
 class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
 
-  @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
-                            (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+  @parameterized.named_parameters(
+      ("1", 1, None),
+      ("2", 2, None),
+      ("3", 4, None),
+      ("4", 8, None),
+      ("5", 16, None),
+      ("6", 4, -1),
+      ("7", 4, 0),
+      ("8", 4, 1),
+      ("9", 4, 4),
+  )
   def testNumThreads(self, num_threads, max_intra_op_parallelism):
 
     def get_thread_id(_):
@@ -60,7 +69,7 @@
     iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       thread_ids = []
       try:
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index d79a842..f994c85 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -45,7 +45,7 @@
     iterator = dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for test_case, expected in test_cases:
         current_test_case = test_case
         sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6..6eaa0b19 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@
     else:
       self.assertEqual(xs, ys)
 
-  @parameterized.parameters(
-      (None, np.int32([]), dtypes.bool),
-      (None, np.int32([]), dtypes.int32),
-      (None, np.int32([]), dtypes.float32),
-      (None, np.int32([]), dtypes.string),
-      (None, np.int32([2]), dtypes.int32),
-      (None, np.int32([2, 2]), dtypes.int32),
-      ((None, None, None), np.int32([]), dtypes.int32),
-      ((None, (None, None)), np.int32([]), dtypes.int32),
+  @parameterized.named_parameters(
+      ("1", None, np.int32([]), dtypes.bool),
+      ("2", None, np.int32([]), dtypes.int32),
+      ("3", None, np.int32([]), dtypes.float32),
+      ("4", None, np.int32([]), dtypes.string),
+      ("5", None, np.int32([2]), dtypes.int32),
+      ("6", None, np.int32([2, 2]), dtypes.int32),
+      ("7", (None, None, None), np.int32([]), dtypes.int32),
+      ("8", (None, (None, None)), np.int32([]), dtypes.int32),
   )
   def testWindowDatasetFlatMap(self, structure, shape, dtype):
     """Tests windowing by chaining it with flat map.
@@ -92,20 +92,20 @@
     dataset = self._structuredDataset(structure, shape, dtype).apply(
         grouping.window_dataset(5)).flat_map(fn)
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       expected = sess.run(self._structuredElement(structure, shape, dtype))
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (None, np.int32([]), dtypes.bool),
-      (None, np.int32([]), dtypes.int32),
-      (None, np.int32([]), dtypes.float32),
-      (None, np.int32([]), dtypes.string),
-      (None, np.int32([2]), dtypes.int32),
-      (None, np.int32([2, 2]), dtypes.int32),
-      ((None, None, None), np.int32([]), dtypes.int32),
-      ((None, (None, None)), np.int32([]), dtypes.int32),
+  @parameterized.named_parameters(
+      ("1", None, np.int32([]), dtypes.bool),
+      ("2", None, np.int32([]), dtypes.int32),
+      ("3", None, np.int32([]), dtypes.float32),
+      ("4", None, np.int32([]), dtypes.string),
+      ("5", None, np.int32([2]), dtypes.int32),
+      ("6", None, np.int32([2, 2]), dtypes.int32),
+      ("7", (None, None, None), np.int32([]), dtypes.int32),
+      ("8", (None, (None, None)), np.int32([]), dtypes.int32),
   )
   def testWindowDatasetBatchDense(self, structure, shape, dtype):
     """Tests batching of dense tensor windows.
@@ -128,17 +128,17 @@
     dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
         grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       expected = sess.run(
           self._structuredElement(structure, np.concatenate(
               ([5], shape), axis=0), dtype))
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int32([]),),
-      (np.int32([1]),),
-      (np.int32([1, 2, 3]),),
+  @parameterized.named_parameters(
+      ("1", np.int32([])),
+      ("2", np.int32([1])),
+      ("3", np.int32([1, 2, 3])),
   )
   def testWindowDatasetBatchDenseDynamicShape(self, shape):
     """Tests batching of dynamically shaped dense tensor windows.
@@ -155,7 +155,7 @@
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {shape_t: shape})
       expected = sess.run(
           self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -203,15 +203,15 @@
           for substructure in structure
       ])
 
-  @parameterized.parameters(
-      (None, np.int32([]), dtypes.bool),
-      (None, np.int32([]), dtypes.int32),
-      (None, np.int32([]), dtypes.float32),
-      (None, np.int32([]), dtypes.string),
-      (None, np.int32([2]), dtypes.int32),
-      (None, np.int32([2, 2]), dtypes.int32),
-      ((None, None, None), np.int32([]), dtypes.int32),
-      ((None, (None, None)), np.int32([]), dtypes.int32),
+  @parameterized.named_parameters(
+      ("1", None, np.int32([]), dtypes.bool),
+      ("2", None, np.int32([]), dtypes.int32),
+      ("3", None, np.int32([]), dtypes.float32),
+      ("4", None, np.int32([]), dtypes.string),
+      ("5", None, np.int32([2]), dtypes.int32),
+      ("6", None, np.int32([2, 2]), dtypes.int32),
+      ("7", (None, None, None), np.int32([]), dtypes.int32),
+      ("8", (None, (None, None)), np.int32([]), dtypes.int32),
   )
   def testWindowDatasetBatchSparse(self, structure, shape, dtype):
     """Tests batching of sparse tensor windows.
@@ -235,7 +235,7 @@
         structure, shape, dtype).repeat(5).apply(
             grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       expected = sess.run(
           self._structuredSparseElement(structure,
                                         np.concatenate(([5], shape), axis=0),
@@ -243,10 +243,10 @@
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int32([]),),
-      (np.int32([1]),),
-      (np.int32([1, 2, 3]),),
+  @parameterized.named_parameters(
+      ("1", np.int32([])),
+      ("2", np.int32([1])),
+      ("3", np.int32([1, 2, 3])),
   )
   def testWindowDatasetBatchSparseDynamicShape(self, shape):
     """Tests batching of dynamically shaped sparse tensor windows.
@@ -263,7 +263,7 @@
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {shape_t: shape})
       expected = sess.run(
           self._structuredSparseElement(None,
@@ -284,17 +284,18 @@
               for substructure in structure
           ]))
 
-  @parameterized.parameters(
-      (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
-      (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
-      (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
-      (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
-      (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
-      ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
-      ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+  @parameterized.named_parameters(
+      ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+      ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+      ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+      ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+      ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+      ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("8", (None,
+             (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
   )
   def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
                                         padded_shape):
@@ -320,7 +321,7 @@
         grouping.window_dataset(len(shapes))).apply(
             grouping._map_x_dataset(fn))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
       expected = sess.run(
           self._structuredElement(
@@ -329,10 +330,10 @@
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int32([[1], [2], [3]]), [-1]),
-      (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
-      (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+  @parameterized.named_parameters(
+      ("1", np.int32([[1], [2], [3]]), [-1]),
+      ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+      ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
   )
   def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
     """Tests padded batching of dynamically shaped dense tensor windows.
@@ -351,7 +352,7 @@
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {shapes_t: shapes})
       expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
       expected = sess.run(
@@ -361,9 +362,9 @@
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int32([[1]]), np.int32([0])),
-      (np.int32([[10], [20]]), np.int32([15])),
+  @parameterized.named_parameters(
+      ("1", np.int32([[1]]), np.int32([0])),
+      ("2", np.int32([[10], [20]]), np.int32([15])),
   )
   def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
     """Tests invalid padded batching of dense tensor windows.
@@ -379,7 +380,7 @@
                 grouping._map_x_dataset(
                     lambda x: batching.padded_batch_window(x, padded_shape)))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
 
@@ -420,17 +421,18 @@
           for substructure in structure
       ])
 
-  @parameterized.parameters(
-      (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
-      (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
-      (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
-      (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
-      (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
-      ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
-      ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
-      (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+  @parameterized.named_parameters(
+      ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+      ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+      ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+      ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+      ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+      ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("8", (None,
+             (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+      ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
   )
   def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
                                          padded_shape):
@@ -456,17 +458,17 @@
         structure, shapes, dtype).apply(grouping.window_dataset(
             len(shapes))).apply(grouping._map_x_dataset(fn))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       expected = sess.run(
           self._structuredRaggedSparseElement(structure, shapes, dtype,
                                               padded_shape))
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int64([[1], [2], [3]]), [-1]),
-      (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
-      (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+  @parameterized.named_parameters(
+      ("1", np.int64([[1], [2], [3]]), [-1]),
+      ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+      ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
   )
   def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
                                                      padded_shape):
@@ -487,7 +489,7 @@
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, {shapes_t: shapes})
       expected = sess.run(
           self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -495,9 +497,9 @@
       actual = sess.run(get_next)
       self._assertEqual(expected, actual)
 
-  @parameterized.parameters(
-      (np.int64([[1]]), [0]),
-      (np.int64([[10], [20]]), [15]),
+  @parameterized.named_parameters(
+      ("1", np.int64([[1]]), [0]),
+      ("2", np.int64([[10], [20]]), [15]),
   )
   def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
     """Tests invalid padded batching of sparse tensor windows.
@@ -514,7 +516,7 @@
             grouping._map_x_dataset(
                 lambda x: batching.padded_batch_window(x, padded_shape)))
     get_next = dataset.make_one_shot_iterator().get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index c603ecc..867ee2b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -61,7 +61,7 @@
     return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
 
   def testWrite(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.writer, feed_dict={
               self.filename: self._createFile(),
@@ -71,7 +71,7 @@
 
   def testWriteZLIB(self):
     options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.writer,
           feed_dict={
@@ -84,7 +84,7 @@
 
   def testWriteGZIP(self):
     options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.writer,
           feed_dict={
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 9c2001c..367c159 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -272,9 +272,9 @@
       padding_value = 0
 
   def batch_init_fn(_):
-    return array_ops.fill(
-        array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
-        constant_op.constant(padding_value, dtype=dataset.output_types))
+    batch_shape = array_ops.concat(
+        [np.array([0], dtype=np.int32), padded_shape], 0)
+    return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
 
   def batch_reduce_fn(state, value):
     return array_ops.concat([state, [value]], 0)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6edc1d7..099e10d 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -124,7 +124,8 @@
                               bucket_batch_sizes,
                               padded_shapes=None,
                               padding_values=None,
-                              pad_to_bucket_boundary=False):
+                              pad_to_bucket_boundary=False,
+                              no_padding=False):
   """A transformation that buckets elements in a `Dataset` by length.
 
   Elements of the `Dataset` are grouped together by length and then are padded
@@ -152,6 +153,8 @@
       unknown size to bucket boundary minus 1 (i.e., the maximum length in each
       bucket), and caller must ensure that the source `Dataset` does not contain
       any elements with length longer than `max(bucket_boundaries)`.
+    no_padding: `bool`, indicates whether to pad the batch features (features
+      need to be either of type `tf.SparseTensor` or of same shape).
 
   Returns:
     A `Dataset` transformation function, which can be passed to
@@ -199,7 +202,9 @@
 
     def batching_fn(bucket_id, grouped_dataset):
       """Batch elements in dataset."""
-      batch_size = batch_sizes[bucket_id]
+      batch_size = window_size_fn(bucket_id)
+      if no_padding:
+        return grouped_dataset.batch(batch_size)
       none_filler = None
       if pad_to_bucket_boundary:
         err_msg = ("When pad_to_bucket_boundary=True, elements must have "
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 38c0a09..92d4251 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -220,6 +220,7 @@
     if weights is None:
       # Select inputs with uniform probability.
       logits = [[1.0] * num_datasets]
+
     else:
       # Use the given `weights` as the probability of choosing the respective
       # input.
@@ -245,8 +246,11 @@
       return array_ops.squeeze(
           stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
 
-    selector_input = random_ops.RandomDataset(seed).batch(2).map(
-        select_dataset_constant_logits)
+    selector_input = dataset_ops.MapDataset(
+        random_ops.RandomDataset(seed).batch(2),
+        select_dataset_constant_logits,
+        use_inter_op_parallelism=False)
+
   else:
     # Use each element of the given `weights` dataset as the probability of
     # choosing the respective input.
@@ -259,9 +263,12 @@
       return array_ops.squeeze(
           stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
 
-    selector_input = dataset_ops.Dataset.zip(
-        (logits_ds, random_ops.RandomDataset(seed).batch(2)
-        )).map(select_dataset_varying_logits)
+    logits_and_seeds = dataset_ops.Dataset.zip(
+        (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+    selector_input = dataset_ops.MapDataset(
+        logits_and_seeds,
+        select_dataset_varying_logits,
+        use_inter_op_parallelism=False)
 
   return _DirectedInterleaveDataset(selector_input, datasets)
 
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6..3d0d099 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@
 
   elems = [ops.convert_to_tensor(e) for e in elems]
   output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
-  if not all(s.is_fully_defined() for s in output_shapes):
-    raise ValueError("All fn output shapes must be fully defined.")
   return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index fa1b851..4114b62 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -46,6 +46,21 @@
   return _apply_fn
 
 
+def model():
+  """A transformation that models performance.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+
+  def _apply_fn(dataset):
+    """Function from `Dataset` to `Dataset` that applies the transformation."""
+    return _ModelDataset(dataset)
+
+  return _apply_fn
+
+
 def optimize(optimizations=None):
   """A transformation that applies optimizations.
 
@@ -97,6 +112,32 @@
     return self._input_dataset.output_types
 
 
+class _ModelDataset(dataset_ops.Dataset):
+  """A `Dataset` that acts as an identity, and models performance."""
+
+  def __init__(self, input_dataset):
+    """See `optimize()` for details."""
+    super(_ModelDataset, self).__init__()
+    self._input_dataset = input_dataset
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.model_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+
+
 class _OptimizeDataset(dataset_ops.Dataset):
   """A `Dataset` that acts as an identity, and applies optimizations."""
 
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 7f09ba7..4c46678 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -499,7 +499,8 @@
   # indefinitely, and all batches will be full-sized.
   dataset = dataset.batch(batch_size=batch_size,
                           drop_remainder=num_epochs is None)
-  dataset = dataset.map(map_fn)
+  dataset = dataset_ops.MapDataset(
+      dataset, map_fn, use_inter_op_parallelism=False)
   dataset = dataset.prefetch(prefetch_buffer_size)
 
   return dataset
@@ -778,7 +779,8 @@
 
   # Extract values if the `Example` tensors are stored as key-value tuples.
   if dataset.output_types == (dtypes.string, dtypes.string):
-    dataset = dataset.map(lambda _, v: v)
+    dataset = dataset_ops.MapDataset(
+        dataset, lambda _, v: v, use_inter_op_parallelism=False)
 
   # Apply dataset repeat and shuffle transformations.
   dataset = _maybe_shuffle_and_repeat(
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 30e1992..91a27f9 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -76,7 +76,7 @@
 ```python
 model.compile(loss='mean_squared_error',
               optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
-              distribute=strategy)
+              distribute=distribution)
 ```
 
 To train the model we call Keras `fit` API using the input dataset that we
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index c524d8b..87f76ea 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -708,19 +708,32 @@
     ],
 )
 
-cuda_py_test(
-    name = "keras_test",
+py_library(
+    name = "keras_test_lib",
+    testonly = 1,
     srcs = ["keras_test.py"],
-    additional_deps = [
-        "//third_party/py/numpy",
+    deps = [
+        ":combinations",
         "//tensorflow/contrib/distribute/python:mirrored_strategy",
+        "//tensorflow/contrib/distribute/python:tpu_strategy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:training",
         "//tensorflow/python/estimator:estimator_py",
         "//tensorflow/python/keras",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+cuda_py_test(
+    name = "keras_test",
+    srcs = ["keras_test.py"],
+    additional_deps = [
+        ":keras_test_lib",
     ],
     tags = [
         "multi_and_single_gpu",
+        "no_pip",
         "no_windows_gpu",
         "notsan",
     ],
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 4fa8aa0..77079d0 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -229,6 +229,8 @@
     if not session_config or not self._cluster_spec:
       return
 
+    session_config.isolate_session_state = True
+
     assert self._task_type
     assert self._task_id is not None
 
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 2301ba9..244d1fc 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -50,10 +50,12 @@
 from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
 from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
 from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
+from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2
 from tensorflow.contrib.optimizer_v2 import adam as adam_v2
 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
+from tensorflow.python.training import adagrad
 from tensorflow.python.training import adam
 from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.training import gradient_descent
@@ -328,6 +330,10 @@
     "TPU", lambda: tpu_lib.TPUStrategy(
         TPUClusterResolver(""), steps_per_run=5),
     required_tpu=True)
+tpu_strategy_one_step = NamedDistribution(
+    "TPU", lambda: tpu_lib.TPUStrategy(
+        TPUClusterResolver(""), steps_per_run=1),
+    required_tpu=True)
 # Note that we disable prefetching for testing since prefetching makes
 # the input non-deterministic.
 mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
@@ -343,17 +349,23 @@
 
 
 adam_optimizer_v1_fn = NamedObject(
-    "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
+    "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
 gradient_descent_optimizer_v1_fn = NamedObject(
     "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]
+adagrad_optimizer_v1_fn = NamedObject(
+    "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
+                 adagrad_optimizer_v1_fn]
 
 adam_optimizer_v2_fn = NamedObject(
-    "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1))
+    "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
 gradient_descent_optimizer_v2_fn = NamedObject(
     "GradientDescentV2",
     lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]
+adagrad_optimizer_v2_fn = NamedObject(
+    "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
+optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
+                 adagrad_optimizer_v2_fn]
 
 graph_and_eager_modes = ["graph", "eager"]
 
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a20069c..a84ef04 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -58,13 +58,12 @@
   train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
   train_ds = train_ds.repeat()
   train_ds = train_ds.shuffle(100)
-  train_ds = train_ds.batch(64)
+  train_ds = train_ds.batch(64, drop_remainder=True)
 
   # eval dataset
   eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
   eval_ds = eval_ds.repeat()
-  eval_ds = eval_ds.shuffle(100)
-  eval_ds = eval_ds.batch(64)
+  eval_ds = eval_ds.batch(64, drop_remainder=True)
 
   return train_ds, eval_ds, input_shape
 
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index d39fd572..5f35e38 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -18,9 +18,12 @@
 from __future__ import print_function
 
 import os
+from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.contrib.distribute.python import combinations
 from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
 from tensorflow.contrib.distribute.python import values
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
@@ -31,6 +34,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.ops.parsing_ops import gen_parsing_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.summary.writer import writer_cache
@@ -63,6 +67,32 @@
   return model
 
 
+def multi_inputs_multi_outputs_model():
+  input_a = keras.layers.Input(shape=(16,), name='input_a')
+  input_b = keras.layers.Input(shape=(16,), name='input_b')
+  input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+  dense = keras.layers.Dense(8, name='dense_1')
+
+  interm_a = dense(input_a)
+  # Convert the string input m to numeric values.
+  interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+  interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+  interm_b = dense(input_b)
+  merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+  model = keras.models.Model(
+      inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
+  model.compile(
+      loss='categorical_crossentropy',
+      optimizer=gradient_descent.GradientDescentOptimizer(0.001),
+      metrics={
+          'dense_2': 'categorical_accuracy',
+          'dense_3': 'categorical_accuracy'
+      })
+  return model
+
+
 def get_ds_train_input_fn():
   np.random.seed(_RANDOM_SEED)
   (x_train, y_train), _ = testing_utils.get_test_data(
@@ -91,6 +121,68 @@
   return dataset
 
 
+def get_multi_inputs_multi_outputs_data():
+  (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(16,),
+      num_classes=3,
+      random_seed=_RANDOM_SEED)
+  (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(16,),
+      num_classes=2,
+      random_seed=_RANDOM_SEED)
+  (m_train, _), (m_test, _) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(8,),
+      num_classes=2,
+      random_seed=_RANDOM_SEED)
+
+  c_train = keras.utils.to_categorical(c_train)
+  c_test = keras.utils.to_categorical(c_test)
+  d_train = keras.utils.to_categorical(d_train)
+  d_test = keras.utils.to_categorical(d_test)
+
+  train_data = {
+      'input_a': a_train,
+      'input_b': b_train,
+      'input_m': m_train,
+      'output_c': c_train,
+      'output_d': d_train
+  }
+  test_data = {
+      'input_a': a_test,
+      'input_b': b_test,
+      'input_m': m_test,
+      'output_c': c_test,
+      'output_d': d_test
+  }
+
+  return (train_data, test_data)
+
+
+def batch_wrapper(dataset, batch_size, distribution):
+  # TPUs currently require fully defined input shapes, drop_remainder ensures
+  # the input will have fully defined shapes.
+  if isinstance(distribution, tpu_strategy.TPUStrategy):
+    return dataset.batch(batch_size, drop_remainder=True)
+  else:
+    return dataset.batch(batch_size)
+
+
+def all_combinations():
+  return combinations.combine(
+      distribution=[combinations.default_strategy,
+                    combinations.one_device_strategy,
+                    combinations.mirrored_strategy_with_gpu_and_cpu,
+                    combinations.mirrored_strategy_with_two_gpus,
+                    combinations.tpu_strategy_one_step],
+      mode=['graph'])
+
+
 class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
 
   def setUp(self):
@@ -99,6 +191,8 @@
     gfile.MakeDirs(self._base_dir)
     self._config = run_config_lib.RunConfig(
         tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+    self._dist = mirrored_strategy.MirroredStrategy(
+        devices=['/device:GPU:0', '/device:GPU:1'])
 
   def tearDown(self):
     writer_cache.FileWriterCache.clear()
@@ -152,6 +246,53 @@
     writer_cache.FileWriterCache.clear()
     gfile.DeleteRecursively(self._config.model_dir)
 
+  def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+    train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+    def train_input_fn():
+      input_dict = {
+          'input_a': train_data['input_a'],
+          'input_b': train_data['input_b'],
+          'input_m': train_data['input_m'].astype(np.str)
+      }
+      output_dict = {
+          'dense_2': train_data['output_c'],
+          'dense_3': train_data['output_d']
+      }
+      return dataset_ops.Dataset.from_tensor_slices((input_dict,
+                                                     output_dict)).batch(16)
+
+    def eval_input_fn():
+      input_dict = {
+          'input_a': test_data['input_a'],
+          'input_b': test_data['input_b'],
+          'input_m': test_data['input_m'].astype(np.str)
+      }
+      output_dict = {
+          'dense_2': test_data['output_c'],
+          'dense_3': test_data['output_d']
+      }
+      return dataset_ops.Dataset.from_tensor_slices((input_dict,
+                                                     output_dict)).batch(16)
+
+    self.do_test_multi_inputs_multi_outputs_with_input_fn(
+        train_input_fn, eval_input_fn)
+
+  def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn,
+                                                       eval_input_fn):
+    config = run_config_lib.RunConfig(
+        tf_random_seed=_RANDOM_SEED,
+        model_dir=self._base_dir,
+        train_distribute=self._dist)
+    with self.cached_session():
+      model = multi_inputs_multi_outputs_model()
+      est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
+      baseline_eval_results = est_keras.evaluate(
+          input_fn=eval_input_fn, steps=1)
+      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+
   def test_keras_optimizer_with_distribution_strategy(self):
     dist = mirrored_strategy.MirroredStrategy(
         devices=['/device:GPU:0', '/device:GPU:1'])
@@ -175,7 +316,7 @@
     gfile.DeleteRecursively(self._config.model_dir)
 
 
-class TestWithDistributionStrategy(test.TestCase):
+class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
 
   def test_validating_dataset_input_tensors_with_shape_mismatch(self):
     with self.cached_session():
@@ -215,7 +356,7 @@
           distributed_training_utils.validate_distributed_dataset_inputs(
               strategy, x, y)
 
-  def test_calling_model_on_same_dataset(self):
+  def test_calling_model_with_numpy_arrays(self):
     with self.cached_session():
       x = keras.layers.Input(shape=(3,), name='input')
       y = keras.layers.Dense(4, name='dense')(x)
@@ -228,11 +369,44 @@
                                                      '/device:GPU:0'])
       model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
 
+      inputs = np.zeros((64, 3), dtype=np.float32)
+      targets = np.zeros((64, 4), dtype=np.float32)
+
+      # Call fit with validation data
+      model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
+                validation_data=(inputs, targets))
+
+      # TODO(anjalisridhar): Add tests for inputs so small that the computed
+      # per-tower batch_size or steps value rounds down to 0.
+      model.evaluate(inputs, targets)
+      # with steps
+      model.evaluate(inputs, targets, steps=2)
+      # with batch_size
+      model.evaluate(inputs, targets, batch_size=8)
+
+      model.predict(inputs)
+      # with steps
+      model.predict(inputs, steps=2)
+      # with batch_size
+      model.predict(inputs, batch_size=8)
+
+  @combinations.generate(all_combinations())
+  def test_calling_model_on_same_dataset(self, distribution):
+    with self.cached_session():
+      x = keras.layers.Input(shape=(3,), name='input')
+      y = keras.layers.Dense(4, name='dense')(x)
+      model = keras.Model(x, y)
+
+      optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
+
       inputs = np.zeros((10, 3), dtype=np.float32)
       targets = np.zeros((10, 4), dtype=np.float32)
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(100)
-      dataset = dataset.batch(10)
+      dataset = batch_wrapper(dataset, 10, distribution)
 
       # Call fit with validation data
       model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -241,6 +415,9 @@
                 validation_data=dataset, validation_steps=2)
       model.predict(dataset, steps=2)
 
+  # TODO(priyag): Enable this test for TPU. Currently tuples/dicts don't work
+  # because clone_model's input_tensors argument only seems to accept a list,
+  # not tuples or dicts.
   def test_fit_with_tuple_and_dict_dataset_inputs(self):
     with self.cached_session():
       a = keras.layers.Input(shape=(3,), name='input_a')
@@ -282,7 +459,8 @@
 
       model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
 
-  def test_fit_eval_and_predict_methods_on_dataset(self):
+  @combinations.generate(all_combinations())
+  def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
     with self.cached_session():
       x = keras.layers.Input(shape=(3,), name='input')
       y = keras.layers.Dense(4, name='dense')(x)
@@ -291,16 +469,13 @@
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
       loss = 'mse'
       metrics = ['mae']
-      strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
-                                                     '/device:CPU:0'])
-
-      model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+      model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
 
       inputs = np.zeros((10, 3), dtype=np.float32)
       targets = np.zeros((10, 4), dtype=np.float32)
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(100)
-      dataset = dataset.batch(10)
+      dataset = batch_wrapper(dataset, 10, distribution)
 
       model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
       model.evaluate(dataset, steps=2, verbose=1)
@@ -446,8 +621,7 @@
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(100)
 
-      with self.assertRaisesRegexp(ValueError,
-                                   'expected input to have 2 dimensions'):
+      with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
         model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
 
       # Wrong input shape
@@ -497,6 +671,8 @@
 
 class LossMaskingWithDistributionStrategyTest(test.TestCase):
 
+  # TODO(priyag): Enable all strategies for this test. Currently it does not
+  # work for TPU due to some invalid datatype.
   def test_masking(self):
     with self.cached_session():
       np.random.seed(1337)
@@ -520,24 +696,25 @@
       self.assertEqual(hist.history['loss'][0], 0)
 
 
-class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+class NormalizationLayerWithDistributionStrategyTest(
+    test.TestCase, parameterized.TestCase):
 
-  def test_batchnorm_correctness(self):
+  @combinations.generate(all_combinations())
+  def test_batchnorm_correctness(self, distribution):
     with self.cached_session():
       model = keras.models.Sequential()
       norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
       model.add(norm)
-      strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
-                                                     '/device:GPU:0'])
       model.compile(loss='mse',
                     optimizer=gradient_descent.GradientDescentOptimizer(0.01),
-                    distribute=strategy)
+                    distribute=distribution)
 
       # centered on 5.0, variance 10.0
       x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+      x = x.astype('float32')
       dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
       dataset = dataset.repeat(100)
-      dataset = dataset.batch(32)
+      dataset = batch_wrapper(dataset, 32, distribution)
 
       model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
       out = model.predict(dataset, steps=2)
@@ -547,9 +724,11 @@
       np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
 
 
-class CorrectnessWithDistributionStrategyTest(test.TestCase):
+class CorrectnessWithDistributionStrategyTest(test.TestCase,
+                                              parameterized.TestCase):
 
-  def test_correctness(self):
+  @combinations.generate(all_combinations())
+  def test_correctness(self, distribution):
     with self.cached_session():
       keras.backend.set_image_data_format('channels_last')
       num_samples = 10000
@@ -558,43 +737,43 @@
       x_train = x_train.astype('float32')
       y_train = y_train.astype('float32')
 
-      model = keras.Sequential()
-      model.add(keras.layers.Dense(1, input_shape=(1,)))
+      def fit_and_predict(with_distribution=None):
+        model = keras.Sequential()
+        model.add(keras.layers.Dense(1, input_shape=(1,)))
+        model.compile(
+            loss=keras.losses.mean_squared_error,
+            optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+            distribute=with_distribution)
 
-      # With DistributionStrategy
-      dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
-      dataset_with = dataset_with.batch(32)
-      strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
-                                                             '/device:GPU:0'])
-
-      model.compile(loss=keras.losses.mean_squared_error,
-                    optimizer=gradient_descent.GradientDescentOptimizer(0.5),
-                    distribute=strategy)
-      model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
-      wts_with_ds = model.get_weights()
-
-      x_predict = [[1], [2], [3], [4]]
-      predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
-                                                                     x_predict))
-      predict_dataset_with = predict_dataset_with.batch(2)
-      predict_with_ds = model.predict(predict_dataset_with, steps=1)
-      predict_with_ds = np.reshape(predict_with_ds, (4, 1))
-
-      # Without DistributionStrategy
-      dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+        batch_size = 64
+        if with_distribution:
+          batch_size //= with_distribution.num_towers
+        train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
                                                                 y_train))
-      dataset_without = dataset_without.batch(64)
+        train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+        # Running only 100 steps instead of the full dataset to keep test
+        # duration small.
+        model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
 
-      model.compile(loss=keras.losses.mean_squared_error,
-                    optimizer=gradient_descent.GradientDescentOptimizer(0.5))
-      model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
-      wts_without_ds = model.get_weights()
+        weights = model.get_weights()
 
-      x_predict = [[1], [2], [3], [4]]
-      predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
-          x_predict, x_predict))
-      predict_dataset_without = predict_dataset_without.batch(4)
-      predict_without_ds = model.predict(predict_dataset_without, steps=1)
+        x_predict = [[1.], [2.], [3.], [4.]]
+        predict_batch_size = 4
+        if with_distribution:
+          predict_batch_size //= with_distribution.num_towers
+        predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
+                                                                  x_predict))
+        predict_dataset = batch_wrapper(predict_dataset,
+                                        predict_batch_size, distribution)
+        predict_result = model.predict(predict_dataset, steps=1)
+        predict_result = np.reshape(predict_result, (4, 1))
+
+        return weights, predict_result
+
+      wts_with_ds, predict_with_ds = fit_and_predict(
+          with_distribution=distribution)
+      wts_without_ds, predict_without_ds = fit_and_predict(
+          with_distribution=None)
 
       # Verify that the weights are the same within some limits of tolerance.
       np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
@@ -603,5 +782,8 @@
       np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
 
 
+# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1.
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index bdac4fb..ba147e7 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -183,6 +183,10 @@
                 "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
                 "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
                 "dense/bias/Adam_1"
+            ],
+            "Adagrad": [
+                "dense/kernel/Adagrad", "dense/kernel",
+                "dense/bias/Adagrad", "dense/bias"
             ]
         }
         variables = variables_map[optimizer_fn().get_name()]
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index d1235b7..0c6805d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -572,6 +572,10 @@
                 task_type=None,
                 task_id=None):
     del task_type, task_id
+
+    if session_config:
+      session_config.isolate_session_state = True
+
     if cluster_spec:
       self._initialize_multi_worker(self._num_gpus, cluster_spec)
 
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 88d7768..1125d02 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -412,6 +412,8 @@
     if not session_config or not self._cluster_spec:
       return
 
+    session_config.isolate_session_state = False
+
     assert self._cluster_spec
     assert self._task_type
     assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
index bb10b54..1679910 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -55,14 +55,14 @@
     next_element = iterator.get_next()
 
     output = []
+    # TODO(rohanj): Modify the test to run to the end of the dataset once we
+    # switch to MultiDeviceIterator.
     with self.cached_session() as sess:
-      for _ in range(5):
+      for _ in range(4):
         result = sess.run(next_element)
         self.assertEqual(2, len(result))
         output.extend(result)
-      self.assertEquals(set(range(10)), set(output))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
+      self.assertEquals(set(range(8)), set(output))
 
   def testPrefetchToTwoDevicesWithReinit(self):
     if not test_util.is_gpu_available():
@@ -75,14 +75,14 @@
     iterator = device_dataset.make_initializable_iterator()
     next_element = iterator.get_next()
 
+    # TODO(rohanj): Modify the test to run to the end of the dataset once we
+    # switch to MultiDeviceIterator.
     with self.cached_session() as sess:
       sess.run(iterator.initializer)
-      for _ in range(5):
-        sess.run(next_element)
-      with self.assertRaises(errors.OutOfRangeError):
+      for _ in range(4):
         sess.run(next_element)
       sess.run(iterator.initializer)
-      for _ in range(5):
+      for _ in range(4):
         sess.run(next_element)
 
 
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 32d7444..6ba8397 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -310,4 +310,18 @@
   def get_host_cpu_device(self, host_id):
     if self._tpu_cluster_resolver.get_master() in ('', 'local'):
       return '/replica:0/task:0/device:CPU:0'
-    return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
+    job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+    return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
+
+  def configure(self,
+                session_config=None,
+                cluster_spec=None,
+                task_type=None,
+                task_id=None):
+    del cluster_spec, task_type, task_id
+    if session_config:
+      session_config.isolate_session_state = True
+      cluster_spec = self._tpu_cluster_resolver.cluster_spec()
+      if cluster_spec:
+        session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 97c53ae..9aadc63 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -166,6 +166,7 @@
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
+    tags = ["notap"],
 )
 
 cuda_py_test(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
index a7bd514..1e36b7f 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
@@ -20,8 +20,8 @@
 
 import numpy as np
 
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 196cc41..1337049 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -22,7 +22,6 @@
 from scipy import stats
 
 from tensorflow.contrib import distributions
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import bijectors
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -30,6 +29,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.platform import test
 
 bs = bijectors
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
index 25f2945..ba31697 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
 from tensorflow.python.framework import dtypes
@@ -29,6 +28,7 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.util import deprecation
 
 
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 6959b3e..b4ad33c 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -18,7 +18,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib import linalg
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
@@ -27,6 +26,7 @@
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.ops.distributions import distribution as distribution_lib
 
 # The following two lines are redundant, in a sense. The first enables
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index d840180..74d9d04 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
 from tensorflow.python.framework import ops
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.util import deprecation
 
 
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d911094..c6a23e4 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
 from tensorflow.python.framework import ops
 from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.util import deprecation
 
 
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index f1accaa..49b9de0 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -21,7 +21,6 @@
 import math
 import numpy as np
 
-from tensorflow.contrib import linalg
 from tensorflow.contrib.distributions.python.ops import distribution_util
 from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
 from tensorflow.python.framework import constant_op
@@ -36,6 +35,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.linalg import linalg
 from tensorflow.python.util import deprecation
 
 __all__ = [
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py
index 7d2274d..48d093e 100644
--- a/tensorflow/contrib/eager/python/evaluator_test.py
+++ b/tensorflow/contrib/eager/python/evaluator_test.py
@@ -117,7 +117,7 @@
     self.assertEqual(6.0, results["mean"].numpy())
 
   def testDatasetGraph(self):
-    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+    with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
       e = SimpleEvaluator(IdentityModel())
       ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
       init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -126,7 +126,7 @@
       self.assertEqual(6.0, results["mean"])
 
   def testWriteSummariesGraph(self):
-    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+    with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
       e = SimpleEvaluator(IdentityModel())
       ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
       training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 529c99b..3acecd2 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -1056,7 +1056,7 @@
         "\n",
         "        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
         "\n",
-        "        predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
         "        result.append(index_word[predicted_id])\n",
         "\n",
         "        if index_word[predicted_id] == '<end>':\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 40bc098..e0d5e49 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -610,7 +610,7 @@
         "\n",
         "    # using a multinomial distribution to predict the word returned by the model\n",
         "    predictions = predictions / temperature\n",
-        "    predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+        "    predicted_id = tf.argmax(predictions[0]).numpy()\n",
         "    \n",
         "    # We pass the predicted word as the next input to the model\n",
         "    # along with the previous hidden state\n",
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index f1e1f99..560fc8c 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -677,7 +677,7 @@
         "        attention_weights = tf.reshape(attention_weights, (-1, ))\n",
         "        attention_plot[t] = attention_weights.numpy()\n",
         "\n",
-        "        predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+        "        predicted_id = tf.argmax(predictions[0]).numpy()\n",
         "\n",
         "        result += targ_lang.idx2word[predicted_id] + ' '\n",
         "\n",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
index fabd7b3..750bbc6 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
@@ -23,4 +23,4 @@
   https://en.wikipedia.org/wiki/List_of_colors:_N-Z
 
 This example was adapted from
-  https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
+  https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
deleted file mode 100644
index 638c57d..0000000
--- a/tensorflow/contrib/eager/python/examples/scan/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-licenses(["notice"])  # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-cuda_py_test(
-    name = "scan_test",
-    size = "small",
-    srcs = ["scan_test.py"],
-    additional_deps = [
-        "//third_party/py/numpy",
-        "//tensorflow:tensorflow_py",
-    ],
-)
-
-cuda_py_test(
-    name = "scan_graph_test",
-    size = "small",
-    srcs = ["scan_graph_test.py"],
-    additional_deps = [
-        "//third_party/py/numpy",
-        "//tensorflow:tensorflow_py",
-    ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
deleted file mode 100644
index d4b8c89..0000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under graph mode execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
-  def runScan(self, n):
-    elems = np.arange(n)
-    start_time = time.time()
-    sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
-    with tf.Session() as sess:
-      sess.run(sum_op)
-    wall_time = time.time() - start_time
-
-    self.report_benchmark(
-        name='scan',
-        iters=n,
-        wall_time=wall_time)
-
-  def benchmarkScan16000(self):
-    self.runScan(16000)
-
-  def benchmarkScan32000(self):
-    self.runScan(32000)
-
-  def benchmarkScan64000(self):
-    self.runScan(64000)
-
-  def benchmarkScan128000(self):
-    self.runScan(128000)
-
-if __name__ == '__main__':
-  tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
deleted file mode 100644
index a02fc24..0000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under eager execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
-  def runScan(self, n):
-    elems = np.arange(n)
-    start_time = time.time()
-    _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
-    wall_time = time.time() - start_time
-
-    self.report_benchmark(
-        name='scan',
-        iters=n,
-        wall_time=wall_time)
-
-  def benchmarkScan16000(self):
-    self.runScan(16000)
-
-  def benchmarkScan32000(self):
-    self.runScan(32000)
-
-  def benchmarkScan64000(self):
-    self.runScan(64000)
-
-  def benchmarkScan128000(self):
-    self.runScan(128000)
-
-
-if __name__ == '__main__':
-  tf.enable_eager_execution()
-  tf.test.main()
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index dcc7b71..9d2d172 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -216,7 +216,7 @@
     self.assertEqual(m1.numer.name, "has_space/numer:0")
 
   def testGraphWithPlaceholder(self):
-    with context.graph_mode(), self.test_session() as sess:
+    with context.graph_mode(), self.cached_session() as sess:
       m = metrics.Mean()
       p = array_ops.placeholder(dtypes.float32)
       accumulate = m(p)
@@ -309,7 +309,7 @@
     self.assertTrue(old_numer is m.numer)
 
   def testMetricsChain(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       m1 = metrics.Mean()
       m2 = metrics.Mean(name="m2")
       update_m2 = m2(3.0)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df..6db311d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@
         ":boosted_trees",
         ":dnn",
         ":dnn_linear_combined",
+        ":dnn_with_layer_annotations",
         ":early_stopping",
         ":export",
         ":exporter",
@@ -127,6 +128,61 @@
 )
 
 py_library(
+    name = "dnn_with_layer_annotations",
+    srcs = ["python/estimator/dnn_with_layer_annotations.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:nn",
+        "//tensorflow/python:partitioned_variables",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/estimator",
+        "//tensorflow/python/estimator:head",
+        "//tensorflow/python/estimator:model_fn",
+        "//tensorflow/python/estimator:optimizers",
+        "//tensorflow/python/feature_column",
+        "//tensorflow/python/ops/losses",
+        "//tensorflow/python/saved_model:utils",
+    ],
+)
+
+py_test(
+    name = "dnn_with_layer_annotations_test",
+    size = "medium",
+    srcs = ["python/estimator/dnn_with_layer_annotations_test.py"],
+    shard_count = 4,
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_pip",
+        "notsan",  # b/67510291
+    ],
+    deps = [
+        ":dnn_with_layer_annotations",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python/estimator:dnn",
+        "//tensorflow/python/estimator:dnn_testing_utils",
+        "//tensorflow/python/estimator:export_export",
+        "//tensorflow/python/estimator:numpy_io",
+        "//tensorflow/python/estimator:pandas_io",
+        "//tensorflow/python/estimator:prediction_keys",
+        "//tensorflow/python/feature_column",
+        "@six_archive//:six",
+    ],
+)
+
+py_library(
     name = "dnn_linear_combined",
     srcs = ["python/estimator/dnn_linear_combined.py"],
     srcs_version = "PY2AND3",
@@ -446,6 +502,7 @@
         "//tensorflow/python/estimator",
         "//tensorflow/python/estimator:head",
         "//tensorflow/python/estimator:optimizers",
+        "//tensorflow/python/ops/losses",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 258860f..78914ec 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -22,6 +22,7 @@
 from tensorflow.contrib.estimator.python.estimator.baseline import *
 from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
 from tensorflow.contrib.estimator.python.estimator.dnn import *
+from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import *
 from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
 from tensorflow.contrib.estimator.python.estimator.early_stopping import *
 from tensorflow.contrib.estimator.python.estimator.export import *
@@ -76,6 +77,9 @@
     'build_raw_supervised_input_receiver_fn',
     'build_supervised_input_receiver_fn_from_input_fn',
-    'SavedModelEstimator'
+    'SavedModelEstimator',
+    'DNNClassifierWithLayerAnnotations',
+    'DNNRegressorWithLayerAnnotations',
 ]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
new file mode 100644
index 0000000..152431d
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -0,0 +1,434 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep Neural Network estimators with layer annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import pickle
+
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.saved_model import utils as saved_model_utils
+
+
+class LayerAnnotationsCollectionNames(object):
+  """Names for the collections containing the annotations."""
+
+  UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features'
+  PROCESSED_FEATURES = 'layer_annotatons/processed_features'
+  FEATURE_COLUMNS = 'layer_annotations/feature_columns'
+
+  @classmethod
+  def keys(cls, collection_name):
+    return '%s/keys' % collection_name
+
+  @classmethod
+  def values(cls, collection_name):
+    return '%s/values' % collection_name
+
+
+def serialize_feature_column(feature_column):
+  if isinstance(feature_column, feature_column_lib._EmbeddingColumn):  # pylint: disable=protected-access
+    # We can't pickle nested functions, and we don't need the value of
+    # layer_creator in most cases anyway, so just discard its value.
+    args = feature_column._asdict()
+    args['layer_creator'] = None
+    temp = type(feature_column)(**args)
+    return pickle.dumps(temp)
+  return pickle.dumps(feature_column)
+
+
+def _to_any_wrapped_tensor_info(tensor):
+  """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`."""
+  any_buf = Any()
+  tensor_info = saved_model_utils.build_tensor_info(tensor)
+  any_buf.Pack(tensor_info)
+  return any_buf
+
+
+def make_input_layer_with_layer_annotations(original_input_layer, mode):
+  """Make an input_layer replacement function that adds layer annotations."""
+
+  def input_layer_with_layer_annotations(features,
+                                         feature_columns,
+                                         weight_collections=None,
+                                         trainable=True,
+                                         cols_to_vars=None,
+                                         cols_to_output_tensors=None):
+    """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+    Generally a single example in training data is described with
+    FeatureColumns.
+    At the first layer of the model, this column oriented data should be
+    converted
+    to a single `Tensor`.
+
+    This is like tf.feature_column.input_layer, except with added
+    Integrated-Gradient annotations.
+
+    Args:
+      features: A mapping from key to tensors. `_FeatureColumn`s look up via
+        these keys. For example `numeric_column('price')` will look at 'price'
+        key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
+        on corresponding `_FeatureColumn`.
+      feature_columns: An iterable containing the FeatureColumns to use as
+        inputs to your model. All items should be instances of classes derived
+        from `_DenseColumn` such as `numeric_column`, `embedding_column`,
+        `bucketized_column`, `indicator_column`. If you have categorical
+        features, you can wrap them with an `embedding_column` or
+        `indicator_column`.
+      weight_collections: A list of collection names to which the Variable will
+        be added. Note that variables will also be added to collections
+        `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+      trainable: If `True` also add the variable to the graph collection
+        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+      cols_to_vars: If not `None`, must be a dictionary that will be filled with
+        a mapping from `_FeatureColumn` to list of `Variable`s.  For example,
+        after the call, we might have cols_to_vars = {_EmbeddingColumn(
+        categorical_column=_HashedCategoricalColumn( key='sparse_feature',
+        hash_bucket_size=5, dtype=tf.string), dimension=10): [<tf.Variable
+        'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1'
+          shape=(5, 10)]} If a column creates no variables, its value will be an
+          empty list.
+      cols_to_output_tensors: If not `None`, must be a dictionary that will be
+        filled with a mapping from '_FeatureColumn' to the associated output
+        `Tensor`s.
+
+    Returns:
+      A `Tensor` which represents input layer of a model. Its shape
+      is (batch_size, first_layer_dimension) and its dtype is `float32`.
+      first_layer_dimension is determined based on given `feature_columns`.
+
+    Raises:
+      ValueError: features and feature_columns have different lengths.
+    """
+
+    local_cols_to_output_tensors = {}
+    input_layer = original_input_layer(
+        features=features,
+        feature_columns=feature_columns,
+        weight_collections=weight_collections,
+        trainable=trainable,
+        cols_to_vars=cols_to_vars,
+        cols_to_output_tensors=local_cols_to_output_tensors)
+
+    if cols_to_output_tensors is not None:
+      cols_to_output_tensors = local_cols_to_output_tensors
+
+    if mode and mode == model_fn.ModeKeys.PREDICT:
+      # Only annotate in PREDICT mode.
+
+      # Annotate features.
+      # These are the parsed Tensors, before embedding.
+
+      # Only annotate features used by FeatureColumns.
+      # We figure which ones are used by FeatureColumns by creating a parsing
+      # spec and looking at the keys.
+      spec = feature_column_lib.make_parse_example_spec(feature_columns)
+      for key in spec.keys():
+        tensor = features[key]
+        ops.add_to_collection(
+            LayerAnnotationsCollectionNames.keys(
+                LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+        ops.add_to_collection(
+            LayerAnnotationsCollectionNames.values(
+                LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+            _to_any_wrapped_tensor_info(tensor))
+
+      # Annotate feature columns.
+      for column in feature_columns:
+        # TODO(cyfoo): Find a better way to serialize and deserialize
+        # _FeatureColumn.
+        ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+                              serialize_feature_column(column))
+
+      for column, tensor in local_cols_to_output_tensors.items():
+        ops.add_to_collection(
+            LayerAnnotationsCollectionNames.keys(
+                LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+            column.name)
+        ops.add_to_collection(
+            LayerAnnotationsCollectionNames.values(
+                LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+            _to_any_wrapped_tensor_info(tensor))
+
+    return input_layer
+
+  return input_layer_with_layer_annotations
+
+
+@contextlib.contextmanager
+def _monkey_patch(module, function, replacement):
+  old_function = getattr(module, function)
+  setattr(module, function, replacement)
+  yield
+  setattr(module, function, old_function)
+
+
+def DNNClassifierWithLayerAnnotations(  # pylint: disable=invalid-name
+    hidden_units,
+    feature_columns,
+    model_dir=None,
+    n_classes=2,
+    weight_column=None,
+    label_vocabulary=None,
+    optimizer='Adagrad',
+    activation_fn=nn.relu,
+    dropout=None,
+    input_layer_partitioner=None,
+    config=None,
+    warm_start_from=None,
+    loss_reduction=losses.Reduction.SUM):
+  """A classifier for TensorFlow DNN models with layer annotations.
+
+  This classifier is functionally identical to estimator.DNNClassifier as far as
+  training and evaluating models is concerned. The key difference is that this
+  classifier adds additional layer annotations, which can be used for computing
+  Integrated Gradients.
+
+  Integrated Gradients is a method for attributing a classifier's predictions
+  to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+  instance, the method assigns attribution scores to individual features in
+  proportion to the feature's importance to the classifier's prediction.
+
+  See estimator.DNNClassifier for example code for training and evaluating models
+  using this classifier.
+
+  This classifier is checkpoint-compatible with estimator.DNNClassifier and
+  therefore the following should work seamlessly:
+
+  # Instantiate ordinary estimator as usual.
+  estimator = tf.estimator.DNNClassifier(
+    config, feature_columns, hidden_units, ...)
+
+  # Train estimator, export checkpoint.
+  tf.estimator.train_and_evaluate(estimator, ...)
+
+  # Instantiate estimator with annotations with the same configuration as the
+  # ordinary estimator.
+  estimator_with_annotations = (
+    tf.contrib.estimator.DNNClassifierWithLayerAnnotations(
+      config, feature_columns, hidden_units, ...))
+
+  # Call export_savedmodel with the same arguments as the ordinary estimator,
+  # using the checkpoint produced for the ordinary estimator.
+  estimator_with_annotations.export_saved_model(
+    export_dir_base, serving_input_receiver, ...
+    checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+  Args:
+    hidden_units: Iterable of number hidden units per layer. All layers are
+      fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+      one has 32.
+    feature_columns: An iterable containing all the feature columns used by the
+      model. All items in the set should be instances of classes derived from
+      `_FeatureColumn`.
+    model_dir: Directory to save model parameters, graph and etc. This can also
+      be used to load checkpoints from the directory into a estimator to
+      continue training a previously saved model.
+    n_classes: Number of label classes. Defaults to 2, namely binary
+      classification. Must be > 1.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example. If it is a string, it is
+      used as a key to fetch weight tensor from the `features`. If it is a
+      `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+      weight_column.normalizer_fn is applied on it to get weight tensor.
+    label_vocabulary: A list of strings represents possible label values. If
+      given, labels must be string type and have any value in
+      `label_vocabulary`. If it is not given, that means labels are already
+      encoded as integer or float within [0, 1] for `n_classes=2` and encoded as
+      integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there
+      will be errors if vocabulary is not provided and labels are string.
+    optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+      to Adagrad optimizer.
+    activation_fn: Activation function applied to each layer. If `None`, will
+      use `tf.nn.relu`.
+    dropout: When not `None`, the probability we will drop out a given
+      coordinate.
+    input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+      `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+    config: `RunConfig` object to configure the runtime settings.
+    warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+      `WarmStartSettings` object to fully configure warm-starting.  If the
+      string filepath is provided instead of a `WarmStartSettings`, then all
+      weights are warm-started, and it is assumed that vocabularies and Tensor
+      names are unchanged.
+    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+      reduce training loss over batch. Defaults to `SUM`.
+
+  Returns:
+    DNNClassifier with layer annotations.
+  """
+
+  original = dnn.DNNClassifier(
+      hidden_units=hidden_units,
+      feature_columns=feature_columns,
+      model_dir=model_dir,
+      n_classes=n_classes,
+      weight_column=weight_column,
+      label_vocabulary=label_vocabulary,
+      optimizer=optimizer,
+      activation_fn=activation_fn,
+      dropout=dropout,
+      input_layer_partitioner=input_layer_partitioner,
+      config=config,
+      warm_start_from=warm_start_from,
+      loss_reduction=loss_reduction)
+
+  def _model_fn(features, labels, mode, config):
+    with _monkey_patch(
+        feature_column_lib, 'input_layer',
+        make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+                                                mode)):
+      return original.model_fn(features, labels, mode, config)
+
+  return estimator.Estimator(
+      model_fn=_model_fn,
+      model_dir=model_dir,
+      config=config,
+      warm_start_from=warm_start_from)
+
+
+def DNNRegressorWithLayerAnnotations(  # pylint: disable=invalid-name
+    hidden_units,
+    feature_columns,
+    model_dir=None,
+    label_dimension=1,
+    weight_column=None,
+    optimizer='Adagrad',
+    activation_fn=nn.relu,
+    dropout=None,
+    input_layer_partitioner=None,
+    config=None,
+    warm_start_from=None,
+    loss_reduction=losses.Reduction.SUM,
+):
+  """A regressor for TensorFlow DNN models with layer annotations.
+
+  This regressor is functionally identical to estimator.DNNRegressor as far as
+  training and evaluating models is concerned. The key difference is that this
+  regressor adds additional layer annotations, which can be used for computing
+  Integrated Gradients.
+
+  Integrated Gradients is a method for attributing a classifier's predictions
+  to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+  instance, the method assigns attribution scores to individual features in
+  proportion to the feature's importance to the classifier's prediction.
+
+  See estimator.DNNRegressor for example code for training and evaluating models
+  using this regressor.
+
+  This regressor is checkpoint-compatible with estimator.DNNRegressor and
+  therefore the following should work seamlessly:
+
+  # Instantiate ordinary estimator as usual.
+  estimator = tf.estimator.DNNRegressor(
+    config, feature_columns, hidden_units, ...)
+
+  # Train estimator, export checkpoint.
+  tf.estimator.train_and_evaluate(estimator, ...)
+
+  # Instantiate estimator with annotations with the same configuration as the
+  # ordinary estimator.
+  estimator_with_annotations = (
+    tf.contrib.estimator.DNNRegressorWithLayerAnnotations(
+      config, feature_columns, hidden_units, ...))
+
+  # Call export_saved_model with the same arguments as the ordinary estimator,
+  # using the checkpoint produced for the ordinary estimator.
+  estimator_with_annotations.export_saved_model(
+    export_dir_base, serving_input_receiver, ...
+    checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+  Args:
+    hidden_units: Iterable of number hidden units per layer. All layers are
+      fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+      one has 32.
+    feature_columns: An iterable containing all the feature columns used by the
+      model. All items in the set should be instances of classes derived from
+      `_FeatureColumn`.
+    model_dir: Directory to save model parameters, graph and etc. This can also
+      be used to load checkpoints from the directory into a estimator to
+      continue training a previously saved model.
+    label_dimension: Number of regression targets per example. This is the size
+      of the last dimension of the labels and logits `Tensor` objects
+      (typically, these have shape `[batch_size, label_dimension]`).
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example. If it is a string, it is
+      used as a key to fetch weight tensor from the `features`. If it is a
+      `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+      weight_column.normalizer_fn is applied on it to get weight tensor.
+    optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+      to Adagrad optimizer.
+    activation_fn: Activation function applied to each layer. If `None`, will
+      use `tf.nn.relu`.
+    dropout: When not `None`, the probability we will drop out a given
+      coordinate.
+    input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+      `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+    config: `RunConfig` object to configure the runtime settings.
+    warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+      `WarmStartSettings` object to fully configure warm-starting.  If the
+      string filepath is provided instead of a `WarmStartSettings`, then all
+      weights are warm-started, and it is assumed that vocabularies and Tensor
+      names are unchanged.
+    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+      reduce training loss over batch. Defaults to `SUM`.
+
+  Returns:
+    DNNRegressor with layer annotations.
+  """
+
+  original = dnn.DNNRegressor(
+      hidden_units=hidden_units,
+      feature_columns=feature_columns,
+      model_dir=model_dir,
+      label_dimension=label_dimension,
+      weight_column=weight_column,
+      optimizer=optimizer,
+      activation_fn=activation_fn,
+      dropout=dropout,
+      input_layer_partitioner=input_layer_partitioner,
+      config=config,
+      warm_start_from=warm_start_from,
+      loss_reduction=loss_reduction,
+  )
+
+  def _model_fn(features, labels, mode, config):
+    with _monkey_patch(
+        feature_column_lib, 'input_layer',
+        make_input_layer_with_layer_annotations(feature_column_lib.input_layer,
+                                                mode)):
+      return original.model_fn(features, labels, mode, config)
+
+  return estimator.Estimator(
+      model_fn=_model_fn,
+      model_dir=model_dir,
+      config=config,
+      warm_start_from=warm_start_from)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
new file mode 100644
index 0000000..2fe3d4c
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
@@ -0,0 +1,611 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dnn_with_layer_annotations.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import dnn_with_layer_annotations
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import dnn_testing_utils
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import queue_runner
+
+try:
+  # pylint: disable=g-import-not-at-top
+  import pandas as pd
+  HAS_PANDAS = True
+except IOError:
+  # Pandas writes a temporary file during import. If it fails, don't use pandas.
+  HAS_PANDAS = False
+except ImportError:
+  HAS_PANDAS = False
+
+
+def _dnn_classifier_fn(*args, **kwargs):
+  return dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+      *args, **kwargs)
+
+
+class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+                          test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
+                                                       _dnn_regressor_fn)
+
+
+class DNNWithLayerAnnotationsClassifierEvaluateTest(
+    dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+        self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsPredictTest(
+    dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+        self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsTrainTest(
+    dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+        self, _dnn_classifier_fn)
+
+
+def _dnn_regressor_fn(*args, **kwargs):
+  return dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+      *args, **kwargs)
+
+
+class DNNWithLayerAnnotationsTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def _getLayerAnnotationCollection(self, graph, collection_name):
+    keys = graph.get_collection(
+        dnn_with_layer_annotations.LayerAnnotationsCollectionNames.keys(
+            collection_name))
+    values = graph.get_collection(
+        dnn_with_layer_annotations.LayerAnnotationsCollectionNames.values(
+            collection_name))
+    if len(keys) != len(values):
+      raise ValueError('keys and values should have same length. lengths were: '
+                       '%d and %d, and elements were %s and %s' %
+                       (len(keys), len(values), keys, values))
+    return dict(zip(keys, values))
+
+  def _testAnnotationsPresentForEstimator(self, estimator_class):
+    feature_columns = [
+        feature_column.numeric_column('x', shape=(1,)),
+        feature_column.embedding_column(
+            feature_column.categorical_column_with_vocabulary_list(
+                'y', vocabulary_list=['a', 'b', 'c']),
+            dimension=3)
+    ]
+    estimator = estimator_class(
+        hidden_units=(2, 2),
+        feature_columns=feature_columns,
+        model_dir=self._model_dir)
+    model_fn = estimator.model_fn
+
+    graph = ops.Graph()
+    with graph.as_default():
+      model_fn({
+          'x': array_ops.constant([1.0]),
+          'y': array_ops.constant(['a'])
+      }, {},
+               model_fn_lib.ModeKeys.PREDICT,
+               config=None)
+
+      unprocessed_features = self._getLayerAnnotationCollection(
+          graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+          .UNPROCESSED_FEATURES)
+      processed_features = self._getLayerAnnotationCollection(
+          graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+          .PROCESSED_FEATURES)
+      feature_columns = graph.get_collection(
+          dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+          .FEATURE_COLUMNS)
+
+      self.assertItemsEqual(unprocessed_features.keys(), ['x', 'y'])
+      self.assertEqual(2, len(processed_features.keys()))
+      self.assertEqual(2, len(feature_columns))
+
+  def testAnnotationsPresentForClassifier(self):
+    self._testAnnotationsPresentForEstimator(
+        dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations)
+
+  def testAnnotationsPresentForRegressor(self):
+    self._testAnnotationsPresentForEstimator(
+        dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations)
+
+  def _testCheckpointCompatibleWithNonAnnotatedEstimator(
+      self, train_input_fn, predict_input_fn, non_annotated_class,
+      annotated_class, prediction_key, estimator_args):
+    input_dimension = 2
+    feature_columns = [
+        feature_column.numeric_column('x', shape=(input_dimension,))
+    ]
+    estimator = non_annotated_class(
+        model_dir=self._model_dir,
+        hidden_units=(2, 2),
+        feature_columns=feature_columns,
+        **estimator_args)
+
+    estimator.train(train_input_fn, steps=10)
+
+    predictions = np.array(
+        [x[prediction_key] for x in estimator.predict(predict_input_fn)])
+
+    annotated_estimator = annotated_class(
+        model_dir=self._model_dir,
+        hidden_units=(2, 2),
+        feature_columns=feature_columns,
+        warm_start_from=self._model_dir,
+        **estimator_args)
+
+    annotated_predictions = np.array([
+        x[prediction_key] for x in annotated_estimator.predict(predict_input_fn)
+    ])
+
+    self.assertAllEqual(predictions.shape, annotated_predictions.shape)
+    for i, (a, b) in enumerate(
+        zip(predictions.flatten(), annotated_predictions.flatten())):
+      self.assertAlmostEqual(a, b, msg='index=%d' % i)
+
+  def testCheckpointCompatibleForClassifier(self):
+    n_classes = 2
+    input_dimension = 2
+    batch_size = 10
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+    x_data = data.reshape(batch_size, input_dimension)
+    y_data = np.reshape(
+        np.rint(data[:batch_size]).astype(np.int64), (batch_size, 1))
+    # learn y = x
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': x_data},
+        y=y_data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+    self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+        train_input_fn,
+        predict_input_fn,
+        dnn.DNNClassifier,
+        dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations,
+        prediction_key=prediction_keys.PredictionKeys.PROBABILITIES,
+        estimator_args={'n_classes': n_classes})
+
+  def testCheckpointCompatibleForRegressor(self):
+    label_dimension = 2
+    batch_size = 10
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+    # learn y = x
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data}, batch_size=batch_size, shuffle=False)
+
+    self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+        train_input_fn,
+        predict_input_fn,
+        dnn.DNNRegressor,
+        dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations,
+        prediction_key=prediction_keys.PredictionKeys.PREDICTIONS,
+        estimator_args={'label_dimension': label_dimension})
+
+
+class DNNRegressorWithLayerAnnotationsEvaluateTest(
+    dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+        self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsPredictTest(
+    dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+        self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsTrainTest(
+    dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
+
+  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
+    test.TestCase.__init__(self, methodName)
+    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+        self, _dnn_regressor_fn)
+
+
+def _queue_parsed_features(feature_map):
+  tensors_to_enqueue = []
+  keys = []
+  for key, tensor in six.iteritems(feature_map):
+    keys.append(key)
+    tensors_to_enqueue.append(tensor)
+  queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+  input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+  queue_runner.add_queue_runner(
+      queue_runner.QueueRunner(input_queue,
+                               [input_queue.enqueue(tensors_to_enqueue)]))
+  dequeued_tensors = input_queue.dequeue()
+  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+class DNNRegressorWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+                          input_dimension, label_dimension, batch_size):
+    feature_columns = [
+        feature_column.numeric_column('x', shape=(input_dimension,))
+    ]
+    est = dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+        hidden_units=(2, 2),
+        feature_columns=feature_columns,
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    num_steps = 10
+    est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn('loss', six.iterkeys(scores))
+
+    # PREDICT
+    predictions = np.array([
+        x[prediction_keys.PredictionKeys.PREDICTIONS]
+        for x in est.predict(predict_input_fn)
+    ])
+    self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+    # EXPORT
+    feature_spec = feature_column.make_parse_example_spec(feature_columns)
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def test_numpy_input_fn(self):
+    """Tests complete flow with numpy_input_fn."""
+    label_dimension = 2
+    batch_size = 10
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+    # learn y = x
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data}, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=label_dimension,
+        label_dimension=label_dimension,
+        batch_size=batch_size)
+
+  def test_pandas_input_fn(self):
+    """Tests complete flow with pandas_input_fn."""
+    if not HAS_PANDAS:
+      return
+    label_dimension = 1
+    batch_size = 10
+    data = np.linspace(0., 2., batch_size, dtype=np.float32)
+    x = pd.DataFrame({'x': data})
+    y = pd.Series(data)
+    train_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+    eval_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, shuffle=False)
+    predict_input_fn = pandas_io.pandas_input_fn(
+        x=x, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=label_dimension,
+        label_dimension=label_dimension,
+        batch_size=batch_size)
+
+  def test_input_fn_from_parse_example(self):
+    """Tests complete flow with input_fn constructed from parse_example."""
+    label_dimension = 2
+    batch_size = 10
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+
+    serialized_examples = []
+    for datum in data:
+      example = example_pb2.Example(
+          features=feature_pb2.Features(
+              feature={
+                  'x':
+                      feature_pb2.Feature(
+                          float_list=feature_pb2.FloatList(value=datum)),
+                  'y':
+                      feature_pb2.Feature(
+                          float_list=feature_pb2.FloatList(value=datum)),
+              }))
+      serialized_examples.append(example.SerializeToString())
+
+    feature_spec = {
+        'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+    }
+
+    def _train_input_fn():
+      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+      features = _queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _eval_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = _queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _predict_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = _queue_parsed_features(feature_map)
+      features.pop('y')
+      return features, None
+
+    self._test_complete_flow(
+        train_input_fn=_train_input_fn,
+        eval_input_fn=_eval_input_fn,
+        predict_input_fn=_predict_input_fn,
+        input_dimension=label_dimension,
+        label_dimension=label_dimension,
+        batch_size=batch_size)
+
+
+class DNNClassifierWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _as_label(self, data_in_float):
+    return np.rint(data_in_float).astype(np.int64)
+
+  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+                          input_dimension, n_classes, batch_size):
+    feature_columns = [
+        feature_column.numeric_column('x', shape=(input_dimension,))
+    ]
+    est = dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+        hidden_units=(2, 2),
+        feature_columns=feature_columns,
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    num_steps = 10
+    est.train(train_input_fn, steps=num_steps)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn('loss', six.iterkeys(scores))
+
+    # PREDICT
+    predicted_proba = np.array([
+        x[prediction_keys.PredictionKeys.PROBABILITIES]
+        for x in est.predict(predict_input_fn)
+    ])
+    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+    # EXPORT
+    feature_spec = feature_column.make_parse_example_spec(feature_columns)
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def test_numpy_input_fn(self):
+    """Tests complete flow with numpy_input_fn."""
+    n_classes = 3
+    input_dimension = 2
+    batch_size = 10
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+    x_data = data.reshape(batch_size, input_dimension)
+    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+    # learn y = x
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': x_data},
+        y=y_data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        n_classes=n_classes,
+        batch_size=batch_size)
+
+  def test_pandas_input_fn(self):
+    """Tests complete flow with pandas_input_fn."""
+    if not HAS_PANDAS:
+      return
+    input_dimension = 1
+    n_classes = 3
+    batch_size = 10
+    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
+    x = pd.DataFrame({'x': data})
+    y = pd.Series(self._as_label(data))
+    train_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+    eval_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, shuffle=False)
+    predict_input_fn = pandas_io.pandas_input_fn(
+        x=x, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        n_classes=n_classes,
+        batch_size=batch_size)
+
+  def test_input_fn_from_parse_example(self):
+    """Tests complete flow with input_fn constructed from parse_example."""
+    input_dimension = 2
+    n_classes = 3
+    batch_size = 10
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, input_dimension)
+
+    serialized_examples = []
+    for datum in data:
+      example = example_pb2.Example(
+          features=feature_pb2.Features(
+              feature={
+                  'x':
+                      feature_pb2.Feature(
+                          float_list=feature_pb2.FloatList(value=datum)),
+                  'y':
+                      feature_pb2.Feature(
+                          int64_list=feature_pb2.Int64List(
+                              value=self._as_label(datum[:1]))),
+              }))
+      serialized_examples.append(example.SerializeToString())
+
+    feature_spec = {
+        'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+    }
+
+    def _train_input_fn():
+      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+      features = _queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _eval_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = _queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _predict_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = _queue_parsed_features(feature_map)
+      features.pop('y')
+      return features, None
+
+    self._test_complete_flow(
+        train_input_fn=_train_input_fn,
+        eval_input_fn=_eval_input_fn,
+        predict_input_fn=_predict_input_fn,
+        input_dimension=input_dimension,
+        n_classes=n_classes,
+        batch_size=batch_size)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd0..98660bb 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
 from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@
                weight_column=None,
                label_vocabulary=None,
                optimizer='Adagrad',
+               loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
                input_layer_partitioner=None,
                config=None):
     """Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@
         string.
       optimizer: An instance of `tf.Optimizer` or string specifying optimizer
         type. Defaults to Adagrad optimizer.
+      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
       input_layer_partitioner: Optional. Partitioner for input layer. Defaults
         to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
       config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@
     if n_classes == 2:
       head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint: disable=protected-access
           weight_column=weight_column,
-          label_vocabulary=label_vocabulary)
+          label_vocabulary=label_vocabulary,
+          loss_reduction=loss_reduction)
     else:
       head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
-          n_classes, weight_column=weight_column,
-          label_vocabulary=label_vocabulary)
+          n_classes,
+          weight_column=weight_column,
+          label_vocabulary=label_vocabulary,
+          loss_reduction=loss_reduction)
+
     def _model_fn(features, labels, mode, config):
       return _rnn_model_fn(
           features=features,
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b403..1aebed3 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@
 
     # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
     # See that test for loss calculation.
-    mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+    mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
 
     sequence_feature_columns = [
         seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@
 
     # Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
     # See that test for loss calculation.
-    mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+    mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
 
     sequence_feature_columns = [
         seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@
     # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
     # loss = -label * ln(p) - (1 - label) * ln(1 - p)
     #      = [[0.436326], [0.683335]]
+    # sum_over_batch_size = (0.436326 + 0.683335)/2
     expected_metrics = {
-        ops.GraphKeys.GLOBAL_STEP: global_step,
-        metric_keys.MetricKeys.LOSS: 1.119661,
-        metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
-        metric_keys.MetricKeys.ACCURACY: 1.0,
-        metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
-        metric_keys.MetricKeys.LABEL_MEAN: 0.5,
-        metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+        ops.GraphKeys.GLOBAL_STEP:
+            global_step,
+        metric_keys.MetricKeys.LOSS:
+            0.559831,
+        metric_keys.MetricKeys.LOSS_MEAN:
+            0.559831,
+        metric_keys.MetricKeys.ACCURACY:
+            1.0,
+        metric_keys.MetricKeys.PREDICTION_MEAN:
+            0.429262,
+        metric_keys.MetricKeys.LABEL_MEAN:
+            0.5,
+        metric_keys.MetricKeys.ACCURACY_BASELINE:
+            0.5,
         # With default threshold of 0.5, the model is a perfect classifier.
-        metric_keys.MetricKeys.RECALL: 1.0,
-        metric_keys.MetricKeys.PRECISION: 1.0,
+        metric_keys.MetricKeys.RECALL:
+            1.0,
+        metric_keys.MetricKeys.PRECISION:
+            1.0,
         # Positive example is scored above negative, so AUC = 1.0.
-        metric_keys.MetricKeys.AUC: 1.0,
-        metric_keys.MetricKeys.AUC_PR: 1.0,
+        metric_keys.MetricKeys.AUC:
+            1.0,
+        metric_keys.MetricKeys.AUC_PR:
+            1.0,
     }
     self.assertAllClose(
         sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@
     #                          [0.059494, 0.572639, 0.367866]]
     # loss = -1. * log(softmax[label])
     #      = [[2.105432], [0.557500]]
+    # sum_over_batch_size = (2.105432 + 0.557500)/2
     expected_metrics = {
         ops.GraphKeys.GLOBAL_STEP: global_step,
-        metric_keys.MetricKeys.LOSS: 2.662932,
+        metric_keys.MetricKeys.LOSS: 1.331465,
         metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
         metric_keys.MetricKeys.ACCURACY: 0.5,
     }
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index bb5140a..6aa62fb 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -126,7 +126,7 @@
     observed *= num_rows / 3. if test_rows else num_cols / 2.
     want_weight_sum = unobserved + observed
 
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       wals_model = factorization_ops.WALSModel(
           input_rows=num_rows,
           input_cols=num_cols,
@@ -161,7 +161,7 @@
   def _run_test_process_input(self,
                               use_factors_weights_cache,
                               compute_loss=False):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       self._wals_inputs = self.sparse_input()
       sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
       num_rows = 5
@@ -330,7 +330,7 @@
   def _run_test_process_input_transposed(self,
                                          use_factors_weights_cache,
                                          compute_loss=False):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       self._wals_inputs = self.sparse_input()
       sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
       num_rows = 5
@@ -505,7 +505,7 @@
   # trigger the more efficient ALS updates.
   # Here we test that those two give identical results.
   def _run_test_als(self, use_factors_weights_cache):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       self._wals_inputs = self.sparse_input()
       col_init = np.random.rand(7, 3)
       als_model = factorization_ops.WALSModel(
@@ -583,7 +583,7 @@
           atol=1e-2)
 
   def _run_test_als_transposed(self, use_factors_weights_cache):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       self._wals_inputs = self.sparse_input()
       col_init = np.random.rand(7, 3)
       als_model = factorization_ops.WALSModel(
@@ -673,7 +673,7 @@
     rows = 15
     cols = 11
     dims = 3
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       data = np.dot(np.random.rand(rows, 3), np.random.rand(
           3, cols)).astype(np.float32) / 3.0
       indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -703,7 +703,7 @@
     cols = 11
     dims = 3
 
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       data = np.dot(np.random.rand(rows, 3), np.random.rand(
           3, cols)).astype(np.float32) / 3.0
       indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -736,7 +736,7 @@
     def keep_index(x):
       return not (x[0] + x[1]) % 4
 
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       row_wts = 0.1 + np.random.rand(rows)
       col_wts = 0.1 + np.random.rand(cols)
       data = np.dot(np.random.rand(rows, 3), np.random.rand(
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
index 888c3c2..112e4d2 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
@@ -99,7 +99,7 @@
     logging.info('Numpy took %f', time.time() - start_time)
 
     start_time = time.time()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       op = gmm_ops._covariance(
           constant_op.constant(
               data.T, dtype=dtypes.float32), False)
@@ -120,7 +120,7 @@
     graph = ops.Graph()
     with graph.as_default() as g:
       g.seed = 5
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         data = constant_op.constant(self.data, dtype=dtypes.float32)
         loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm(
             data, 'random', num_classes, random_seed=self.seed)
@@ -144,7 +144,7 @@
   def testParams(self):
     """Tests that the params work as intended."""
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Experiment 1. Update weights only.
       data = constant_op.constant(self.data, dtype=dtypes.float32)
       gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index 88eb9cf..1ab5418 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -232,7 +232,7 @@
     self.assertEqual(features.shape, parsed_feature_dict.shape)
     self.assertEqual(features.dtype, parsed_feature_dict.dtype)
     # Then check that running the tensor yields the original list of points.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       parsed_points = sess.run(parsed_feature_dict)
       self.assertAllEqual(self.points, parsed_points)
 
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 31820a1..9bdbd05 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -336,7 +336,7 @@
     loss = self._model.evaluate(
         input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
 
-    with self.test_session():
+    with self.cached_session():
       true_loss = self.calculate_loss()
 
     self.assertNear(
@@ -354,7 +354,7 @@
     loss = self._model.evaluate(
         input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
 
-    with self.test_session():
+    with self.cached_session():
       true_loss = self.calculate_loss()
 
     self.assertNear(
@@ -440,7 +440,7 @@
                          math_ops.logical_not(is_row_sweep_var)))
     mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sweep_hook = wals_lib._SweepHook(
           is_row_sweep_var,
           is_sweep_done_var,
@@ -491,7 +491,7 @@
     train_op = state_ops.assign_add(completed_sweeps, 1)
     hook.begin()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([variables.global_variables_initializer()])
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(train_op)
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index b1b5126..45a67ac 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -24,11 +24,13 @@
 from tensorflow.contrib.util import loader
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import resource_loader
+from tensorflow.python.util.deprecation import deprecated
 
 _ffmpeg_so = loader.load_op_library(
     resource_loader.get_path_to_datafile('ffmpeg.so'))
 
 
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
 def decode_audio(contents, file_format=None, samples_per_second=None,
                  channel_count=None, stream=None):
   """Create an op that decodes the contents of an audio file.
@@ -69,6 +71,7 @@
 ops.NotDifferentiable('DecodeAudio')
 
 
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
 def encode_audio(audio, file_format=None, samples_per_second=None):
   """Creates an op that encodes an audio file using sampled audio from a tensor.
 
@@ -95,6 +98,7 @@
 ops.NotDifferentiable('EncodeAudio')
 
 
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
 def decode_video(contents):
   """Create an op that decodes the contents of a video file.
 
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
index 4f59136..77a4241 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
@@ -82,7 +82,7 @@
 
   def testNoTensor(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
     with self.assertRaises(errors_impl.OpError):
       self.assertAllEqual(
@@ -90,7 +90,7 @@
 
   def testGetTensor(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -103,7 +103,7 @@
 
   def testGetAllVariables(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _create_checkpoints(session, checkpoint_dir)
     self.assertEqual(
         checkpoint_utils.list_variables(checkpoint_dir),
@@ -112,7 +112,7 @@
 
   def testInitFromCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -146,7 +146,7 @@
 
   def testInitWithScopeDoesNotCaptureSuffixes(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
 
     with ops.Graph().as_default() as g:
@@ -165,7 +165,7 @@
 
   def testInitFromRootCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -189,7 +189,7 @@
 
   def testInitToRootCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -212,7 +212,7 @@
 
   def testInitFromPartitionVar(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1 = _create_partition_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -266,7 +266,7 @@
 
   def testInitFromCheckpointMissing(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 2479fe5..b1820c1 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -39,7 +39,7 @@
 class LocalVariabletest(test.TestCase):
 
   def test_local_variable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEquals([], variables_lib.local_variables())
       value0 = 42
       variables_lib2.local_variable(value0)
@@ -55,7 +55,7 @@
 class ReduceSumNTest(test.TestCase):
 
   def test_reduce_sum_n(self):
-    with self.test_session():
+    with self.cached_session():
       a = constant_op.constant(1)
       b = constant_op.constant([2])
       c = constant_op.constant([[3, 4], [5, 6]])
@@ -119,13 +119,13 @@
                                   }))
 
   def test_with_shape_invalid_expected_shape(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertRaisesRegexp(ValueError, "Invalid rank",
                               tensor_util.with_shape, [[1], [2]],
                               constant_op.constant(1.0))
 
   def test_with_shape_invalid_type(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertRaisesRegexp(ValueError, "Invalid dtype",
                               tensor_util.with_shape, [1.1],
                               constant_op.constant([1.0]))
@@ -138,7 +138,7 @@
                               constant_op.constant(1.0))
 
   def test_with_shape_0(self):
-    with self.test_session():
+    with self.cached_session():
       value = 42
       shape = [0]
       unexpected_shapes = [[1], [2], [1, 1]]
@@ -150,7 +150,7 @@
           unexpected_shapes)
 
   def test_with_shape_1(self):
-    with self.test_session():
+    with self.cached_session():
       value = [42]
       shape = [1]
       unexpected_shapes = [[0], [2], [1, 1]]
@@ -162,7 +162,7 @@
           unexpected_shapes)
 
   def test_with_shape_2(self):
-    with self.test_session():
+    with self.cached_session():
       value = [42, 43]
       shape = [2]
       unexpected_shapes = [[0], [1], [2, 1]]
@@ -174,7 +174,7 @@
           unexpected_shapes)
 
   def test_with_shape_2x2(self):
-    with self.test_session():
+    with self.cached_session():
       value = [[42, 43], [44, 45]]
       shape = [2, 2]
       unexpected_shapes = [[0], [1], [2, 1]]
@@ -196,7 +196,7 @@
       np.testing.assert_array_equal(value, tensor_with_shape.eval())
 
   def test_with_shape_none(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_no_shape = array_ops.placeholder(dtypes.float32)
 
       compatible_shape = [2, 2]
@@ -220,7 +220,7 @@
 
   @test_util.enable_c_shapes
   def test_with_shape_partial(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_partial_shape = array_ops.placeholder(dtypes.float32)
       tensor_partial_shape.set_shape([None, 2])
 
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0ccb458..716bb87 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -174,7 +174,7 @@
 
     // Input bias is a 1-D tensor, with size matching output depth.
     const Tensor& bias = context->input(kBias);
-    OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
+    OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
 
     const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
     const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index ab98865..7243f15 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -184,7 +184,7 @@
       return _get_estimator_spec(
           mode, gan_model, generator_loss_fn, discriminator_loss_fn,
           get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
-          get_hooks_fn)
+          get_hooks_fn, use_loss_summaries)
 
     super(GANEstimator, self).__init__(
         model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -211,15 +211,17 @@
 def _get_estimator_spec(
     mode, gan_model, generator_loss_fn, discriminator_loss_fn,
     get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
-    get_hooks_fn=None):
+    get_hooks_fn=None, use_loss_summaries=True):
   """Get the EstimatorSpec for the current mode."""
   if mode == model_fn_lib.ModeKeys.PREDICT:
     estimator_spec = model_fn_lib.EstimatorSpec(
         mode=mode, predictions=gan_model.generated_data)
   else:
     gan_loss = tfgan_tuples.GANLoss(
-        generator_loss=generator_loss_fn(gan_model),
-        discriminator_loss=discriminator_loss_fn(gan_model))
+        generator_loss=generator_loss_fn(
+            gan_model, add_summaries=use_loss_summaries),
+        discriminator_loss=discriminator_loss_fn(
+            gan_model, add_summaries=use_loss_summaries))
     if mode == model_fn_lib.ModeKeys.EVAL:
       estimator_spec = _get_eval_estimator_spec(
           gan_model, gan_loss, get_eval_metric_ops_fn)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 9ac9c6c..83f8dd6 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -116,7 +116,7 @@
       discriminator_fn=None)
 
 
-def dummy_loss_fn(gan_model):
+def dummy_loss_fn(gan_model, add_summaries=True):
   return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
                              gan_model.discriminator_gen_outputs)
 
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 9f5fee4..e3c780a 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -51,7 +51,7 @@
     loss = self._g_loss_fn(self._discriminator_gen_outputs)
     self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
     self.assertEqual(self._generator_loss_name, loss.op.name)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
 
   def test_discriminator_all_correct(self):
@@ -59,7 +59,7 @@
         self._discriminator_real_outputs, self._discriminator_gen_outputs)
     self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
     self.assertEqual(self._discriminator_loss_name, loss.op.name)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
 
   def test_generator_loss_collection(self):
@@ -90,7 +90,7 @@
     loss = self._g_loss_fn(
         array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
     self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
 
   def test_discriminator_patch(self):
@@ -98,7 +98,7 @@
         array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
         array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
     self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
 
   def test_generator_loss_with_placeholder_for_logits(self):
@@ -108,7 +108,7 @@
     loss = self._g_loss_fn(logits, weights=weights)
     self.assertEqual(logits.dtype, loss.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: [[10.0, 4.4, -5.5, 3.6]],
@@ -125,7 +125,7 @@
         logits, logits2, real_weights=real_weights,
         generated_weights=generated_weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: [self._discriminator_real_outputs_np],
@@ -136,7 +136,7 @@
   def test_generator_with_python_scalar_weight(self):
     loss = self._g_loss_fn(
         self._discriminator_gen_outputs, weights=self._weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss * self._weights,
                              loss.eval(), 4)
 
@@ -144,14 +144,14 @@
     loss = self._d_loss_fn(
         self._discriminator_real_outputs, self._discriminator_gen_outputs,
         real_weights=self._weights, generated_weights=self._weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss * self._weights,
                              loss.eval(), 4)
 
   def test_generator_with_scalar_tensor_weight(self):
     loss = self._g_loss_fn(self._discriminator_gen_outputs,
                            weights=constant_op.constant(self._weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss * self._weights,
                              loss.eval(), 4)
 
@@ -160,7 +160,7 @@
     loss = self._d_loss_fn(
         self._discriminator_real_outputs, self._discriminator_gen_outputs,
         real_weights=weights, generated_weights=weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss * self._weights,
                              loss.eval(), 4)
 
@@ -284,7 +284,7 @@
     self.assertEqual(
         self._discriminator_gen_classification_logits.dtype, loss.dtype)
     self.assertEqual(self._generator_loss_name, loss.op.name)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
 
   def test_discriminator_all_correct(self):
@@ -292,7 +292,7 @@
     self.assertEqual(
         self._discriminator_gen_classification_logits.dtype, loss.dtype)
     self.assertEqual(self._discriminator_loss_name, loss.op.name)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
 
   def test_generator_loss_collection(self):
@@ -319,14 +319,14 @@
     patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
                   self._generator_kwargs.items()}
     loss = self._g_loss_fn(**patch_args)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
 
   def test_discriminator_patch(self):
     patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
                   self._discriminator_kwargs.items()}
     loss = self._d_loss_fn(**patch_args)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
 
   def test_generator_loss_with_placeholder_for_logits(self):
@@ -334,7 +334,7 @@
     one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
 
     loss = self._g_loss_fn(gen_logits, one_hot_labels)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(
           loss, feed_dict={
               gen_logits: self._discriminator_gen_classification_logits_np,
@@ -349,7 +349,7 @@
 
     loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(
           loss, feed_dict={
               gen_logits: self._discriminator_gen_classification_logits_np,
@@ -360,7 +360,7 @@
 
   def test_generator_with_python_scalar_weight(self):
     loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss * self._weights,
                              loss.eval(), 4)
 
@@ -368,14 +368,14 @@
     loss = self._d_loss_fn(
         real_weights=self._weights, generated_weights=self._weights,
         **self._discriminator_kwargs)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss * self._weights,
                              loss.eval(), 4)
 
   def test_generator_with_scalar_tensor_weight(self):
     loss = self._g_loss_fn(
         weights=constant_op.constant(self._weights), **self._generator_kwargs)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_g_loss * self._weights,
                              loss.eval(), 4)
 
@@ -383,7 +383,7 @@
     weights = constant_op.constant(self._weights)
     loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
                            **self._discriminator_kwargs)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(self._expected_d_loss * self._weights,
                              loss.eval(), 4)
 
@@ -404,7 +404,7 @@
     loss = self._penalty_fn(**self._kwargs)
     self.assertEqual(self._expected_dtype, loss.dtype)
     self.assertEqual(self._expected_op_name, loss.op.name)
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
 
@@ -419,13 +419,13 @@
 
   def test_python_scalar_weight(self):
     loss = self._penalty_fn(weights=2.3, **self._kwargs)
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
 
   def test_scalar_tensor_weight(self):
     loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
 
@@ -472,7 +472,7 @@
         self._kwargs['discriminator_scope'])
     self.assertEqual(generated_data.dtype, loss.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables.global_variables_initializer().run()
       loss = sess.run(loss,
                       feed_dict={
@@ -494,7 +494,7 @@
         one_sided=True)
     self.assertEqual(generated_data.dtype, loss.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables.global_variables_initializer().run()
       loss = sess.run(loss,
                       feed_dict={
@@ -516,7 +516,7 @@
         self._kwargs['discriminator_scope'],
         target=2.0)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables.global_variables_initializer().run()
       loss = sess.run(
           loss,
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbf..25d74a8 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@
 
   def consistency_test(self):
     self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(arg_loss(**loss_args).eval(),
                        tuple_loss(_tuple_from_dict(loss_args)).eval())
 
@@ -241,7 +241,7 @@
         self.discriminator_generated_data_source_predication)
     wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       loss_result, wrapped_loss_result = sess.run(
           [loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@
         self.discriminator_generated_data_source_predication)
     wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       loss_result, wrapped_loss_result = sess.run(
           [loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@
         discriminator_scope=self.discriminator_scope)
     wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       loss_result, wrapped_loss_result = sess.run(
           [loss_result_tensor, wrapped_loss_result_tensor])
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index 80b2d3e..2bf6097 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/core/platform/file_system.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 static const size_t kSyncMarkerSize = 16;
@@ -332,9 +333,10 @@
   };
   DataTypeVector output_types_;
 };
-}  // namespace
 
 REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU),
                         SequenceFileDatasetOp);
 
+}  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index c7b4e2f..be915ef 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -49,7 +49,7 @@
     y_solved = odes.odeint(func, y0, t)
     self.assertIn('odeint', y_solved.name)
     self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11]))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_solved = sess.run(y_solved)
     y_true = np.exp(t)
     self.assertAllClose(y_true, y_solved)
@@ -62,7 +62,7 @@
     func = lambda y, t: k * y
     t = np.linspace(0.0, 1.0, 11)
     y_solved = odes.odeint(func, 1.0 + 0.0j, t)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_solved = sess.run(y_solved)
     y_true = np.exp(k * t)
     self.assertAllClose(y_true, y_solved)
@@ -74,7 +74,7 @@
     func = lambda t, y: (y - t)**2 + 1.0
     t = np.linspace(0.0, 1.0, 11)
     y_solved = odes.odeint(func, np.float64(0.5), t)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_solved = sess.run(y_solved)
     y_true = 1.0 / (2.0 - t) + t
     self.assertAllClose(y_true, y_solved)
@@ -96,7 +96,7 @@
     t = np.linspace(0.0, 1.0, 11)
 
     y_solved = odes.odeint(func, y0, t)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_solved = sess.run(y_solved)
 
     y_true = np.zeros((len(t), 2, 1))
@@ -113,7 +113,7 @@
       y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t)
       self.assertEqual(y_solved.get_shape(),
                        tensor_shape.TensorShape(expected_shape))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         y_solved = sess.run(y_solved)
         self.assertEquals(y_solved.shape, expected_shape)
 
@@ -126,7 +126,7 @@
       for t_dtype in [dtypes.float32, dtypes.float64]:
         y0 = math_ops.cast(1.0, y0_dtype)
         y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype))
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           y_solved = sess.run(y_solved)
         expected = np.asarray(np.exp(t))
         self.assertAllClose(y_solved, expected, rtol=1e-5)
@@ -148,13 +148,13 @@
         self.y0, [0, 1],
         method='dopri5',
         options={'max_num_steps': 0})
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    'max_num_steps'):
         sess.run(y)
 
     y = odes.odeint(self.func, self.y0, [1, 0])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    'monotonic increasing'):
         sess.run(y)
@@ -164,7 +164,7 @@
     times0 = np.linspace(0, 10, num=11, dtype=float)
     times1 = np.linspace(0, 10, num=101, dtype=float)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_solved_0, info_0 = sess.run(
           odes.odeint(self.func, self.y0, times0, full_output=True))
       y_solved_1, info_1 = sess.run(
@@ -179,7 +179,7 @@
     t = [0, 20]
     kwargs = dict(
         full_output=True, method='dopri5', options=dict(max_num_steps=2000))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _, info_0 = sess.run(
           odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
       _, info_1 = sess.run(
@@ -196,7 +196,7 @@
     new_step = odes._optimal_step_size(
         last_step=constant_op.constant(1.0),
         error_ratio=constant_op.constant(1.0))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       new_step = sess.run(new_step)
     self.assertAllClose(new_step, 0.9)
 
@@ -204,7 +204,7 @@
     new_step = odes._optimal_step_size(
         last_step=constant_op.constant(1.0),
         error_ratio=constant_op.constant(0.0))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       new_step = sess.run(new_step)
     self.assertAllClose(new_step, 10.0)
 
@@ -212,7 +212,7 @@
     new_step = odes._optimal_step_size(
         last_step=constant_op.constant(1.0),
         error_ratio=constant_op.constant(1e6))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       new_step = sess.run(new_step)
     self.assertAllClose(new_step, 0.2)
 
@@ -229,13 +229,13 @@
     y_fit = array_ops.stack(
         [odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times])
     y_expected = f(times)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_actual = sess.run(y_fit)
       self.assertAllClose(y_expected, y_actual)
 
     # attempt interpolation outside bounds
     y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors_impl.InvalidArgumentError):
         sess.run(y_invalid)
 
@@ -251,7 +251,7 @@
     y0 = [0., 1.]
     y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_grid_array = sess.run(y_grid)
 
     np.testing.assert_allclose(
@@ -265,7 +265,7 @@
     y0 = [1.]
     y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y_grid_array = sess.run(y_grid)
 
     np.testing.assert_allclose(
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 7ede193..124515e 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -109,7 +109,7 @@
     return sparse_ids, sparse_weights
 
   def test_safe_embedding_lookup_sparse_return_zero_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -122,7 +122,7 @@
            3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
 
   def test_safe_embedding_lookup_sparse_return_special_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -136,7 +136,7 @@
            embedding_weights[0][2], embedding_weights[0][3]])
 
   def test_safe_embedding_lookup_sparse_no_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, _ = self._ids_and_weights_2d()
 
@@ -150,7 +150,7 @@
                embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
 
   def test_safe_embedding_lookup_sparse_partitioned(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, _ = self._ids_and_weights_2d()
 
@@ -164,7 +164,7 @@
                            (embedding_weights[0] + embedding_weights[1]) / 2.0])
 
   def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -179,7 +179,7 @@
                         embedding_weights, sparse_ids, sparse_weights)
 
   def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -192,7 +192,7 @@
       ], [embedding_weights[0][2], [0] * 4, [0] * 4]])
 
   def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -208,7 +208,7 @@
             ]])
 
   def test_safe_embedding_lookup_sparse_3d_no_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, _ = self._ids_and_weights_3d()
 
@@ -224,7 +224,7 @@
           ]])
 
   def test_safe_embedding_lookup_sparse_3d_partitioned(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, _ = self._ids_and_weights_3d()
 
@@ -241,7 +241,7 @@
 
   def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
       self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -276,7 +276,7 @@
     return embedding_weights
 
   def test_scattered_embedding_consistency(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
 
@@ -288,7 +288,7 @@
                           embedding_lookup_result[1])
 
   def test_scattered_embedding_multiple_partition(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=7)
       values = constant_op.constant([4, 4, 5])
 
@@ -304,7 +304,7 @@
       self.assertGreater(embedding_diff, 0)
 
   def test_scattered_embedding_coverage(self):
-    with self.test_session():
+    with self.cached_session():
       size = 8
       embedding_weights = self._random_weights(size=size, num_shards=3)
       values = constant_op.constant(["foo"])
@@ -316,7 +316,7 @@
       self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
 
   def test_scattered_embedding_multi_dimension(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])
@@ -329,7 +329,7 @@
                           embedding_lookup_result[1][2])
 
   def test_scattered_embedding_lookup_sparse(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_tensor = sparse_tensor_lib.SparseTensor(
           values=["foo", "bar", "foo", "bar"],
@@ -358,7 +358,7 @@
     embeds = np.random.randn(n_embed, d_embed)
     idx = np.random.randint(0, n_embed, idx_shape)
 
-    with self.test_session():
+    with self.cached_session():
       embedded_np = embeds[idx]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
 
@@ -370,7 +370,7 @@
     idx = np.random.randint(0, 5, 10)
     idx2d = np.random.randint(0, 5, (10, 2))
 
-    with self.test_session():
+    with self.cached_session():
       embedded_np = embeds[idx]
       embedded_np2d = embeds[idx2d]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -408,7 +408,7 @@
     return embedding_weights
 
   def test_hashed_embedding_consistency(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
       # The first three sampled_candidates are equal, so the first three
@@ -429,7 +429,7 @@
                           embedding_lookup_result[1][3])
 
   def test_hashed_embedding_multi_dimension(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])
@@ -467,7 +467,7 @@
 
   def test_output_shape(self):
     """Verifies the shape of the output tensor."""
-    with self.test_session():
+    with self.cached_session():
       sp_values = sparse_tensor_lib.SparseTensor(
           values=["a", "a", "b", "c", "d", "e", "f"],
           indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -481,7 +481,7 @@
 
   def test_output_values(self):
     """Verifies the values in a trivial case."""
-    with self.test_session():
+    with self.cached_session():
       sp_values = sparse_tensor_lib.SparseTensor(
           values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
       params = constant_op.constant([.1, .2, .3])
@@ -495,7 +495,7 @@
 
   def test_output_values_with_sampled_candidates(self):
     """Verifies the values for given sampled_candidates."""
-    with self.test_session():
+    with self.cached_session():
       sp_values = sparse_tensor_lib.SparseTensor(
           values=["a", "a", "b", "c", "d", "e", "f"],
           indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -520,7 +520,7 @@
 
   def test_output_values_with_sign_hash(self):
     """Verifies the values in a trivial case with hash_signs=True."""
-    with self.test_session():
+    with self.cached_session():
       sp_values = sparse_tensor_lib.SparseTensor(
           values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
       params = constant_op.constant([.1, .1, .1])
@@ -537,7 +537,7 @@
 
   def test_distributive_property(self):
     """Verifies the distributive property of matrix multiplication."""
-    with self.test_session():
+    with self.cached_session():
       params = constant_op.constant([.1, .2, .3])
       sp_values_a = sparse_tensor_lib.SparseTensor(
           values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
@@ -710,7 +710,7 @@
         [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
                                            dtypes.float64], [True, False]):
 
-      with self.test_session():
+      with self.cached_session():
         p, params, feed_dict = _EmbeddingParams(
             num_shards, vocab_size, shape=param_shape, dtype=dtype)
         embedding_sum = \
@@ -749,7 +749,7 @@
     for num_shards, combiner, dtype, ignore_weights in itertools.product(
         [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
                                            dtypes.float64], [True, False]):
-      with self.test_session():
+      with self.cached_session():
         x, params, _ = _EmbeddingParams(
             num_shards, vocab_size, shape=param_shape, dtype=dtype)
 
@@ -767,7 +767,7 @@
       self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
       sp_ids = sparse_tensor_lib.SparseTensor(
           constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
diff --git a/tensorflow/contrib/layers/python/layers/encoders_test.py b/tensorflow/contrib/layers/python/layers/encoders_test.py
index e8528e9..1a2aa71 100644
--- a/tensorflow/contrib/layers/python/layers/encoders_test.py
+++ b/tensorflow/contrib/layers/python/layers/encoders_test.py
@@ -34,14 +34,14 @@
 class EncodersTest(test.TestCase):
 
   def testBowEncoderSparse(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3]]
       enc = encoders.bow_encoder(docs, 4, 3)
       sess.run(variables.global_variables_initializer())
       self.assertAllEqual([2, 3], enc.eval().shape)
 
   def testBowEncoderSparseTensor(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3]]
       sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
       enc = encoders.bow_encoder(sparse_docs, 4, 3)
@@ -49,28 +49,28 @@
       self.assertAllEqual([2, 3], enc.eval().shape)
 
   def testBowEncoderSparseEmptyRow(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3], [0, 0]]
       enc = encoders.bow_encoder(docs, 4, 5)
       sess.run(variables.global_variables_initializer())
       self.assertAllEqual([3, 5], enc.eval().shape)
 
   def testBowEncoderDense(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3], [0, 0], [0, 0]]
       enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False)
       sess.run(variables.global_variables_initializer())
       self.assertAllEqual([4, 3], enc.eval().shape)
 
   def testBowEncoderSparseTensorDenseLookup(self):
-    with self.test_session():
+    with self.cached_session():
       docs = [[0, 1]]
       sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
       with self.assertRaises(TypeError):
         encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False)
 
   def testBowEncodersSharingEmbeddings(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3]]
       enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test')
       enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
@@ -79,7 +79,7 @@
       self.assertAllEqual(avg_1, avg_2)
 
   def testBowEncodersSharingEmbeddingsInheritedScopes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3]]
       with variable_scope.variable_scope('test'):
         enc_1 = encoders.bow_encoder(docs, 4, 3)
@@ -90,7 +90,7 @@
       self.assertAllEqual(avg_1, avg_2)
 
   def testBowEncodersSharingEmbeddingsSharedScope(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[0, 1], [2, 3]]
       enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow')
       variable_scope.get_variable_scope().reuse_variables()
@@ -100,7 +100,7 @@
       self.assertAllEqual(avg_1, avg_2)
 
   def testBowEncoderReuseEmbeddingsVariable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[1, 1], [2, 3]]
       with variable_scope.variable_scope('test'):
         v = _get_const_var('embeddings', (4, 3),
@@ -111,7 +111,7 @@
       self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval())
 
   def testEmbedSequence(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       docs = [[1, 1], [2, 3]]
       with variable_scope.variable_scope('test'):
         v = _get_const_var('embeddings', (4, 3),
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index e6bbd86..6fb4b9f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -49,7 +49,7 @@
     real_valued = feature_column.real_valued_column("price")
     features = {"price": constant_op.constant([[20.], [110], [-3]])}
     output = feature_column_ops._Transformer(features).transform(real_valued)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(output.eval(), [[20.], [110], [-3]])
 
   def testSparseRealValuedColumnIdentityTransformation(self):
@@ -60,7 +60,7 @@
     features = {"rating": rating_tensor}
     output = feature_column_ops._Transformer(features).transform(
         sparse_real_valued)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(output.values.eval(), rating_tensor.values.eval())
       self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
       self.assertAllEqual(output.dense_shape.eval(),
@@ -80,7 +80,7 @@
                                                         [sparse_real_valued])
     self.assertTrue(sparse_real_valued in output_dict)
     output = output_dict[sparse_real_valued]
-    with self.test_session():
+    with self.cached_session():
       self.assertArrayNear(output.values.eval(), [4.0, 25.0], 1e-5)
       self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
       self.assertAllEqual(output.dense_shape.eval(),
@@ -97,7 +97,7 @@
         features=features, feature_columns=[bucket])
     self.assertEqual(len(output), 1)
     self.assertIn(bucket, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(output[bucket].eval(), [[2], [3], [0]])
 
   def testBucketizedColumnWithMultiDimensions(self):
@@ -109,7 +109,7 @@
         "price": constant_op.constant([[20., 110], [110., 20], [-3, -3]])
     }
     output = feature_column_ops._Transformer(features).transform(bucket)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(output.eval(), [[2, 3], [3, 2], [0, 0]])
 
   def testCachedTransformation(self):
@@ -118,7 +118,7 @@
     # buckets 2, 3, 0
     features = {"price": constant_op.constant([[20.], [110], [-3]])}
     transformer = feature_column_ops._Transformer(features)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       transformer.transform(bucket)
       num_of_ops = len(sess.graph.get_operations())
       # Verify that the second call to transform the same feature
@@ -138,7 +138,7 @@
         features=features, feature_columns=[hashed_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(hashed_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
       self.assertTrue(
           all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -161,7 +161,7 @@
         features=features, feature_columns=[hashed_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(hashed_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
       self.assertTrue(
           all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -177,7 +177,7 @@
     features = {"wire": wire_tensor}
     output = feature_column_ops._Transformer(features).transform(hashed_sparse)
 
-    with self.test_session():
+    with self.cached_session():
       # While the input is a dense Tensor, the output should be a SparseTensor.
       self.assertIsInstance(output, sparse_tensor.SparseTensor)
       self.assertEqual(output.values.dtype, dtypes.int64)
@@ -203,7 +203,7 @@
     self.assertEqual(len(output), 2)
     self.assertIn(hashed_sparse, output)
     self.assertIn(wire_embedding, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(output[wire_embedding].indices.eval(),
                           wire_tensor.indices.eval())
       self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
@@ -223,7 +223,7 @@
         features=features, feature_columns=[keys_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(keys_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
       self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
@@ -241,7 +241,7 @@
     features = {"wire": wire_tensor}
     output = feature_column_ops._Transformer(features).transform(keys_sparse)
 
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       # While the input is a dense Tensor, the output should be a SparseTensor.
       self.assertIsInstance(output, sparse_tensor.SparseTensor)
@@ -264,7 +264,7 @@
         features=features, feature_columns=[hashed_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(hashed_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int32)
       self.assertTrue(
           all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -282,7 +282,7 @@
     wire_tensor = constant_op.constant([[100, 0], [1, 25]])
     features = {"wire": wire_tensor}
     output = feature_column_ops._Transformer(features).transform(hashed_sparse)
-    with self.test_session():
+    with self.cached_session():
       # While the input is a dense Tensor, the output should be a SparseTensor.
       self.assertIsInstance(output, sparse_tensor.SparseTensor)
       self.assertEqual(output.values.dtype, dtypes.int32)
@@ -310,7 +310,7 @@
     self.assertEqual(len(output), 1)
     self.assertIn(weighted_ids, output)
 
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
                           ids_tensor.dense_shape.eval())
@@ -340,7 +340,7 @@
         features=features, feature_columns=[vocab_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(vocab_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
       self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -362,7 +362,7 @@
         features=features, feature_columns=[vocab_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(vocab_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
       self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -386,7 +386,7 @@
         features=features, feature_columns=[vocab_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(vocab_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
       self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -408,7 +408,7 @@
         features=features, feature_columns=[vocab_sparse])
     self.assertEqual(len(output), 1)
     self.assertIn(vocab_sparse, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
       self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -440,7 +440,7 @@
         features=features, feature_columns=[country_language])
     self.assertEqual(len(output), 1)
     self.assertIn(country_language, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[country_language].values.dtype, dtypes.int64)
       self.assertTrue(
           all(x < 15 and x >= 0 for x in output[country_language].values.eval(
@@ -467,7 +467,7 @@
         features=features, feature_columns=[country_price])
     self.assertEqual(len(output), 1)
     self.assertIn(country_price, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[country_price].values.dtype, dtypes.int64)
       self.assertTrue(
           all(x < 15 and x >= 0 for x in output[country_price].values.eval()))
@@ -498,7 +498,7 @@
       weights = column_to_variable[country_price][0]
       grad = array_ops.squeeze(
           gradients_impl.gradients(output, weights)[0].values)
-      with self.test_session():
+      with self.cached_session():
         variables_lib.global_variables_initializer().run()
         self.assertEqual(len(grad.eval()), 6)
 
@@ -537,7 +537,7 @@
         features=features, feature_columns=[wire_country_price])
     self.assertEqual(len(output), 1)
     self.assertIn(wire_country_price, output)
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(output[wire_country_price].values.dtype, dtypes.int64)
       self.assertTrue(
           all(x < 15 and x >= 0 for x in output[wire_country_price].values.eval(
@@ -600,7 +600,7 @@
     columns = [one_hot_column, embedding_column, real_valued_column]
     output = feature_column_ops.input_from_feature_columns(features, columns)
     output_core = fc_core.input_layer(features, columns)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
@@ -626,7 +626,7 @@
     cols_to_outs = {}
     feature_column_ops.input_from_feature_columns(
         features, columns, cols_to_outs=cols_to_outs)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       for column in columns:
@@ -637,7 +637,7 @@
     features = {"price": constant_op.constant([[20.], [110], [-3]])}
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [real_valued])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), features["price"].eval())
       # Verify cross compatibility: Core builder output should equal to contrib.
       self.assertAllClose(output.eval(),
@@ -650,7 +650,7 @@
     }
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [real_valued])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), features["price"].eval())
       # Verify cross compatibility: Core builder output should equal to contrib.
       self.assertAllClose(output.eval(),
@@ -662,7 +662,7 @@
     rating = np.array([[0., 1., 2., -1.],
                        [3., 4., 5., 6.]])
     features = {"rating": constant_op.constant(rating)}
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = sess.run(feature_column_ops.input_from_feature_columns(
           features, [var_len_real_valued]))
     self.assertAllClose(rating, output)
@@ -673,7 +673,7 @@
     rating = np.array([[0, 1, 2, -1],
                        [3, 4, 5, 6]])
     features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)}
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = sess.run(feature_column_ops.input_from_feature_columns(
           features, [var_len_real_valued]))
     self.assertAllClose(rating.astype(np.float32), output)
@@ -684,7 +684,7 @@
     features = {"price": constant_op.constant([[20.], [110], [-3]])}
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [real_valued])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), features["price"].eval() - 2)
       # Verify cross compatibility: Core builder output should equal to contrib.
       self.assertAllClose(output.eval(),
@@ -698,7 +698,7 @@
     }
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [real_valued])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), features["price"].eval() - 2)
       # Verify cross compatibility: Core builder output should equal to contrib.
       self.assertAllClose(output.eval(),
@@ -713,7 +713,7 @@
     features = {"price": constant_op.constant([[20.], [110], [-3]])}
     output = feature_column_ops.input_from_feature_columns(features, [bucket])
     expected = [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), expected)
       self.assertAllClose(output.eval(),
                           fc_core.input_layer(features, [bucket]).eval())
@@ -729,7 +729,7 @@
     output = feature_column_ops.input_from_feature_columns(features, [bucket])
     expected = [[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0],
                 [1, 0, 0, 0, 1, 0, 0, 0]]
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(output.eval(), expected)
       self.assertAllClose(output.eval(),
                           fc_core.input_layer(features, [bucket]).eval())
@@ -752,7 +752,7 @@
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [one_hot_column])
     output_core = fc_core.input_layer(features, [one_hot_column])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
@@ -773,7 +773,7 @@
                                                            [one_hot_sparse])
     output_core = fc_core.input_layer(features, [one_hot_sparse])
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
@@ -794,7 +794,7 @@
                                                            [one_hot_sparse])
     output_core = fc_core.input_layer(features, [one_hot_sparse])
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
@@ -816,7 +816,7 @@
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [one_hot_sparse])
     output_core = fc_core.input_layer(features, [one_hot_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
                           output.eval())
@@ -834,7 +834,7 @@
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [one_hot_sparse])
     output_core = fc_core.input_layer(features, [one_hot_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual([3, 10], output.eval().shape)
@@ -852,7 +852,7 @@
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embeded_sparse])
     output_core = fc_core.input_layer(features, [embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(output.eval().shape, [4, 10])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -878,7 +878,7 @@
         features, [embedded_sparse], weight_collections=["my_collection_core"])
     weights_core = ops.get_collection("my_collection_core")
     grad_core = gradients_impl.gradients(output_core, weights_core)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       gradient_values = []
       gradient_values_core = []
@@ -907,7 +907,7 @@
                                                            [embeded_sparse])
     output_core = fc_core.input_layer(features, [embeded_sparse])
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       output_eval = output.eval()
       self.assertAllEqual(output_eval.shape, [2, 10])
@@ -935,7 +935,7 @@
 
     # Makes sure that trying to use different initializers with the same
     # embedding column explicitly fails.
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError,
           "Duplicate feature column key found for column: wire_embedding"):
@@ -961,7 +961,7 @@
                                                            [embeded_sparse])
     output_core = fc_core.input_layer(features, [embeded_sparse])
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(output.eval().shape, [2, 10])
@@ -986,7 +986,7 @@
     embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1005,7 +1005,7 @@
     embeded_sparse = feature_column.embedding_column(crossed, 10)
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(output.eval().shape, [2, 10])
 
@@ -1016,7 +1016,7 @@
         indices=[[0, 0], [1, 0], [1, 1]],
         dense_shape=[2, 2])
     features = {"wire": wire_tensor}
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, "Error creating input layer for column: wire"):
         variables_lib.global_variables_initializer().run()
@@ -1035,7 +1035,7 @@
         indices=[[0, 0], [1, 0], [1, 1]],
         dense_shape=[2, 2])
     features = {"ids": ids_tensor, "weights": weights_tensor}
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError,
           "Error creating input layer for column: ids_weighted_by_weights"):
@@ -1053,7 +1053,7 @@
         indices=[[0, 0], [1, 0], [1, 1]],
         dense_shape=[2, 2])
     features = {"aaa": wire_tensor, "bbb": wire_tensor}
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, "Error creating input layer for column: aaa_X_bbb"):
         variables_lib.global_variables_initializer().run()
@@ -1080,7 +1080,7 @@
         hashed_sparse, 10, initializer=init_ops.constant_initializer(133.7))
     output = feature_column_ops.input_from_feature_columns(
         features, [real_valued, bucket, embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       # size of output = 3 (real_valued) + 2 * 4 (bucket) + 10 (embedding) = 21
       self.assertAllEqual(output.eval().shape, [3, 21])
@@ -1099,7 +1099,7 @@
         initializer=init_ops.ones_initializer())
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       # score: (number of values)
       self.assertAllEqual(output.eval(), [[1.], [2.], [0.]])
@@ -1119,7 +1119,7 @@
         max_norm=0.5)
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embedded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       # score: (number of values * 0.5)
       self.assertAllClose(output.eval(), [[0.5], [1.], [0.]])
@@ -1144,7 +1144,7 @@
         initializer=init_ops.ones_initializer())
     output = feature_column_ops.input_from_feature_columns(features,
                                                            [embeded_sparse])
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       # score: (sum of weights)
@@ -1236,7 +1236,7 @@
     # There should be one trainable variables for sparse_2
     self.assertEqual(1, len(variables_lib.trainable_variables()))
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       output_1_eval = output_1.eval()
       output_2_eval = output_2.eval()
@@ -1295,7 +1295,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [measurement_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       model_inputs = sess.run(model_input_tensor)
     self.assertAllClose(measurement_input, model_inputs)
 
@@ -1305,7 +1305,7 @@
     rating = np.array([[0., 1., 2., -1.],
                        [3., 4., 5., 6.]])
     features = {"rating": constant_op.constant(rating)}
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = sess.run(
           feature_column_ops.sequence_input_from_feature_columns(
               features, [var_len_real_valued]))
@@ -1329,7 +1329,7 @@
     expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
     reshaped_measurements = np.reshape(measurement_input, expected_shape)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       model_inputs = sess.run(model_input_tensor)
 
     self.assertAllClose(reshaped_measurements, model_inputs)
@@ -1350,7 +1350,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [measurement_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       model_inputs = sess.run(model_input_tensor)
     self.assertAllClose(normalizer(measurement_input), model_inputs)
 
@@ -1373,7 +1373,7 @@
     expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
     reshaped_measurements = np.reshape(measurement_input, expected_shape)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       model_inputs = sess.run(model_input_tensor)
 
     self.assertAllClose(normalizer(reshaped_measurements), model_inputs)
@@ -1395,7 +1395,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [one_hot_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input = sess.run(model_input_tensor)
@@ -1429,7 +1429,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [one_hot_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input = sess.run(model_input_tensor)
@@ -1459,7 +1459,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [embedded_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input = sess.run(model_input_tensor)
@@ -1488,7 +1488,7 @@
     model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
         columns_to_tensors, [embedded_column])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input = sess.run(model_input_tensor)
@@ -1518,7 +1518,7 @@
     embedding_weights = ops.get_collection("my_collection")
     gradient_tensor = gradients_impl.gradients(model_input_tensor,
                                                embedding_weights)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
@@ -1585,7 +1585,7 @@
         columns_to_tensors, model_input_columns)
     self.assertEqual(dtypes.float32, model_input_tensor.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       model_input = sess.run(model_input_tensor)
@@ -1622,7 +1622,7 @@
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [hashed_sparse], num_outputs=5)
     logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -1640,7 +1640,7 @@
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [hashed_sparse], num_outputs=5)
     logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -1654,7 +1654,7 @@
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [hashed_sparse], num_outputs=5)
     logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -1676,7 +1676,7 @@
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [weighted_ids], num_outputs=5)
     logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1695,7 +1695,7 @@
         features, [weighted_ids], num_outputs=5)
     logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
 
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1716,7 +1716,7 @@
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [crossed], num_outputs=5)
     logits_core = fc_core.linear_model(features, [crossed], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [2, 5])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -1730,7 +1730,7 @@
         dense_shape=[2, 2])
     features = {"wire": wire_tensor}
     embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, "Error creating weighted sum for column: wire_embedding"):
         variables_lib.global_variables_initializer().run()
@@ -1756,7 +1756,7 @@
               features, [movies], num_outputs=1))
       logits_core = fc_core.linear_model(features, [movies])
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.initialize_all_variables().run()
         lookup_ops.tables_initializer().run()
 
@@ -1776,7 +1776,7 @@
     }
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [real_valued], num_outputs=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [3, 5])
 
@@ -1789,7 +1789,7 @@
     }
     logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
         features, [bucket], num_outputs=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(logits.eval().shape, [3, 5])
 
@@ -1814,7 +1814,7 @@
         features, [real_valued, bucket, hashed_sparse, crossed], num_outputs=5)
     output_core = fc_core.linear_model(
         features, [real_valued, bucket, hashed_sparse, crossed], units=5)
-    with self.test_session():
+    with self.cached_session():
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(output.eval().shape, [3, 5])
       # Verify cross compatibility: Core builder output should equal to contrib.
@@ -1837,7 +1837,7 @@
       output, column_to_variable, bias = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [age, language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -1877,7 +1877,7 @@
               features, [country, language], num_outputs=1))
       # Assert that only a single weight is created.
       self.assertEqual(len(variables), 1)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -1941,7 +1941,7 @@
       output, column_to_variable, bias = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [weighted_language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -1969,7 +1969,7 @@
       output, column_to_variable, bias = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -1992,7 +1992,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [movies], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2026,7 +2026,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country_language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2050,7 +2050,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [language_language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2083,7 +2083,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country_language], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2124,7 +2124,7 @@
                 features, [country, language, country_language],
                 num_outputs=1,
                 scope=scope))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2161,7 +2161,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country, age, incomes], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2197,7 +2197,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country, age, height, incomes], num_outputs=5))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2228,7 +2228,7 @@
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [bucket], num_outputs=1))
       output_core = fc_core.linear_model(features, [bucket])
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         # Cross compatibility: Core builder output should equal to contrib.
@@ -2259,7 +2259,7 @@
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [bucket, country], num_outputs=1))
       output_core = fc_core.linear_model(features, [bucket, country])
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         # Cross compatibility: Core builder output should equal to contrib.
@@ -2290,7 +2290,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [bucket, country], num_outputs=5))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2326,7 +2326,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country_price], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2365,7 +2365,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [country_language_price], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2389,7 +2389,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [product], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         product_weights = column_to_variable[product][0]
@@ -2404,7 +2404,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [product], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         product_weights = column_to_variable[product][0]
@@ -2419,7 +2419,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [product], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         product_weights = column_to_variable[product][0]
@@ -2440,7 +2440,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [product], num_outputs=1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         product_weights = column_to_variable[product][0]
@@ -2452,7 +2452,7 @@
       features = {"age": constant_op.constant([[10.], [20.], [30.], [40.]])}
       output, _, bias = feature_column_ops.weighted_sum_from_feature_columns(
           features, [feature_column.real_valued_column("age")], num_outputs=3)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         sess.run(bias.assign([0.1, 0.2, 0.3]))
@@ -2466,7 +2466,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [column], num_outputs=3))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         weights = column_to_variable[column][0]
@@ -2490,7 +2490,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [column], num_outputs=3))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
         weights = column_to_variable[column][0]
@@ -2516,7 +2516,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [column], num_outputs=3))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2556,7 +2556,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [column], num_outputs=3))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2585,7 +2585,7 @@
       output, column_to_variable, _ = (
           feature_column_ops.weighted_sum_from_feature_columns(
               features, [column], num_outputs=3))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         variables_lib.global_variables_initializer().run()
         lookup_ops.tables_initializer().run()
 
@@ -2651,7 +2651,7 @@
         feature_columns=[bucket, wire_cast])
     self.assertIn(bucket, output)
     self.assertIn(wire_cast, output)
-    with self.test_session():
+    with self.cached_session():
       lookup_ops.tables_initializer().run()
       self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
       self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
@@ -2713,7 +2713,7 @@
     self.assertIn("measurements", seq)
     self.assertIsInstance(seq["measurements"], ops.Tensor)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       location_val, wire_cast_val, measurement_val = sess.run(
           [ctx["location"], seq["wire_cast"], seq["measurements"]])
 
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index eaaf9f8..d90d6ec 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -201,7 +201,7 @@
       b2 = feature_column_ops.input_from_feature_columns({
           b[1]: input_tensor_c2
       }, [b[1]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       b1_value = b1.eval()
       b2_value = b2.eval()
@@ -230,7 +230,7 @@
       e1 = feature_column_ops.input_from_feature_columns({
           e[0]: input_tensor_c1
       }, [e[0]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       d1_value = d1.eval()
       e1_value = e1.eval()
@@ -340,7 +340,7 @@
       with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
         one_hot_output = one_hot._to_dnn_input_layer(
             id_tensor, output_rank=output_rank)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         one_hot_value = sess.run(one_hot_output)
         expected_shape = (id_tensor_shape[:output_rank - 1] + [vocab_size])
         self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -376,7 +376,7 @@
       one_hot_output_shape = one_hot_output.get_shape().as_list()
       expected_shape = id_tensor_shape[:-1] + [vocab_size]
       self.assertEquals(expected_shape, one_hot_output_shape)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         one_hot_value = sess.run(one_hot_output)
         self.assertEquals(expected_shape, list(one_hot_value.shape))
 
@@ -399,7 +399,7 @@
     expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0.,
                           0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
                          [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       one_hot_value = sess.run(one_hot_output)
     self.assertTrue(np.array_equal(one_hot_value, expected))
 
@@ -440,7 +440,7 @@
     }
     one_hot_tensor = feature_column_ops.input_from_feature_columns(
         features, [one_hot])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       sess.run(lookup_ops.tables_initializer())
       self.assertAllEqual([[2., 6., 0.]], one_hot_tensor.eval())
@@ -451,7 +451,7 @@
     features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
     one_hot_tensor = feature_column_ops.input_from_feature_columns(
         features, [one_hot])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       sess.run(lookup_ops.tables_initializer())
       self.assertAllEqual([[1., 1., 0.]], one_hot_tensor.eval())
@@ -603,7 +603,7 @@
         real_valued_output = real_valued_column._to_dnn_input_layer(
             constant_op.constant(real_valued_input, dtype=dtypes.float32),
             output_rank=output_rank)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         real_valued_eval = sess.run(real_valued_output)
       expected_shape = (
           input_shape[:output_rank - 1] +
@@ -797,7 +797,7 @@
     sparse_column.insert_transformed_feature(features)
     sparse_output = features[sparse_column]
     expected_shape = [batch_size, 1]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_result = sess.run(sparse_output)
     self.assertEquals(expected_shape, list(sparse_result.dense_shape))
 
@@ -1110,7 +1110,7 @@
     ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
     checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       saved_embedding = embeddings.eval()
       save.save(sess, checkpoint_path)
@@ -1131,7 +1131,7 @@
           embedding_col_initialized: input_tensor
       }, [embedding_col_initialized])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       loaded_embedding = pretrained_embeddings.eval()
 
@@ -1176,7 +1176,7 @@
     ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
     checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       sess.run(assign_op)
       saved_col_weights = col_weights[crossed_col][0].eval()
@@ -1201,7 +1201,7 @@
           }, [crossed_col_initialized], 1))
       col_weights_from_ckpt = col_weights[crossed_col_initialized][0]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       loaded_col_weights = col_weights_from_ckpt.eval()
 
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 04668f1..a82d4c1 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -3109,7 +3109,7 @@
     inputs: Tensor input
     num_units: Specifies how many features will remain after maxout
       in the `axis` dimension (usually channel).
-      This must be multiple of number of `axis`.
+      This must be a factor of number of features.
     axis: The dimension where max pooling will be performed. Default is the
     last dimension.
     scope: Optional scope for variable_scope.
@@ -3128,7 +3128,7 @@
       raise ValueError('number of features({}) is not '
                        'a multiple of num_units({})'.format(
                            num_channels, num_units))
-    shape[axis] = -1
+    shape[axis] = num_units
     shape += [num_channels // num_units]
 
     # Dealing with batches with arbitrary sizes
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index eee9086..85af9de 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -281,7 +281,7 @@
 
   def testCreate(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = _layers.bias_add(images)
       self.assertEqual(output.op.name, 'BiasAdd/BiasAdd')
@@ -289,7 +289,7 @@
 
   def testCreateWithActivation(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = _layers.bias_add(images, activation_fn=nn_ops.relu)
       self.assertEqual(output.op.name, 'BiasAdd/Relu')
@@ -298,7 +298,7 @@
   def testCreateDimensions(self):
     dims = (2, 3, 4)
     shape = [5, 2, 3, 4]
-    with self.test_session():
+    with self.cached_session():
       for d in dims:
         input_shape = shape[:d]
         inputs = random_ops.random_uniform(input_shape, seed=1)
@@ -311,7 +311,7 @@
 class ConvolutionTest(test.TestCase):
 
   def testInvalidShape(self):
-    with self.test_session():
+    with self.cached_session():
       images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1)
       with self.assertRaisesRegexp(
           ValueError, 'Convolution expects input with rank 5, got 4'):
@@ -323,14 +323,14 @@
 
   def testInvalidDataFormat(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       with self.assertRaisesRegexp(ValueError, 'data_format'):
         layers_lib.convolution2d(images, 32, 3, data_format='CHWN')
 
   def testCreateConv(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
       output = layers_lib.convolution2d(images, 32, [3, 3])
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -342,7 +342,7 @@
 
   def testCreateConvNCHW(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
       output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW')
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -354,7 +354,7 @@
 
   def testCreateSquareConv(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, 3)
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -362,7 +362,7 @@
 
   def testCreateConvWithTensorShape(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, images.get_shape()[1:3])
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -370,7 +370,7 @@
 
   def testCreateFullyConv(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 32), seed=1)
       output = layers_lib.convolution2d(
           images, 64, images.get_shape()[1:3], padding='VALID')
@@ -381,7 +381,7 @@
 
   def testFullyConvWithCustomGetter(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       called = [0]
 
       def custom_getter(getter, *args, **kwargs):
@@ -395,7 +395,7 @@
 
   def testCreateVerticalConv(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 4), seed=1)
       output = layers_lib.convolution2d(images, 32, [3, 1])
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -407,7 +407,7 @@
 
   def testCreateHorizontalConv(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 4), seed=1)
       output = layers_lib.convolution2d(images, 32, [1, 3])
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -417,7 +417,7 @@
 
   def testCreateConvWithStride(self):
     height, width = 6, 8
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, [3, 3], stride=2)
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -427,7 +427,7 @@
   def testCreateConvCreatesWeightsAndBiasesVars(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('conv1/weights'))
       self.assertFalse(variables.get_variables('conv1/biases'))
       layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
@@ -436,7 +436,7 @@
 
   def testCreateConvWithScope(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
       self.assertEqual(output.op.name, 'conv1/Relu')
@@ -453,14 +453,14 @@
 
   def testCreateConvWithoutActivation(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, [3, 3], activation_fn=None)
       self.assertEqual(output.op.name, 'Conv/BiasAdd')
 
   def testCreateConvValid(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.convolution2d(images, 32, [3, 3], padding='VALID')
       self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
@@ -468,7 +468,7 @@
   def testCreateConvWithWD(self):
     height, width = 7, 9
     weight_decay = 0.01
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       regularizer = regularizers.l2_regularizer(weight_decay)
       layers_lib.convolution2d(
@@ -481,7 +481,7 @@
 
   def testCreateConvNoRegularizers(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       layers_lib.convolution2d(images, 32, [3, 3])
       self.assertEqual(
@@ -489,7 +489,7 @@
 
   def testReuseVars(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
       self.assertEqual(len(variables.get_variables()), 2)
@@ -498,7 +498,7 @@
 
   def testNonReuseVars(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       layers_lib.convolution2d(images, 32, [3, 3])
       self.assertEqual(len(variables.get_variables()), 2)
@@ -507,7 +507,7 @@
 
   def testReuseConvWithWD(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       weight_decay = regularizers.l2_regularizer(0.01)
       with arg_scope(
@@ -523,7 +523,7 @@
 
   def testConvWithBatchNorm(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 32), seed=1)
       with arg_scope(
           [layers_lib.convolution2d],
@@ -539,7 +539,7 @@
 
   def testReuseConvWithBatchNorm(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 32), seed=1)
       with arg_scope(
           [layers_lib.convolution2d],
@@ -557,7 +557,7 @@
   def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
     height, width = 7, 9
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('conv1/weights'))
       self.assertFalse(variables.get_variables('conv1/biases'))
       layers_lib.convolution2d(images, 32, [3, 3], rate=2, scope='conv1')
@@ -573,7 +573,7 @@
     output = layers_lib.convolution2d(
         images, num_filters, [3, 3], rate=2, padding='SAME')
     self.assertListEqual(list(output.get_shape().as_list()), expected_size)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -587,7 +587,7 @@
     output = layers_lib.convolution2d(
         images, num_filters, [3, 3], rate=2, padding='VALID')
     self.assertListEqual(list(output.get_shape().as_list()), expected_size)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -601,7 +601,7 @@
     output = layers_lib.convolution2d(
         images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
     self.assertListEqual(list(output.get_shape().as_list()), expected_size)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEquals(output.op.name, 'Conv/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -612,7 +612,7 @@
     expected_size = [None, None, None, num_filters]
     expected_size_dynamic = [5, 7, 9, num_filters]
 
-    with self.test_session():
+    with self.cached_session():
       images = array_ops.placeholder(np.float32,
                                      [None, None, None, input_size[3]])
       output = layers_lib.convolution2d(
@@ -651,7 +651,7 @@
     expected_size = [None, None, None, num_filters]
     expected_size_dynamic = [5, 5, 7, num_filters]
 
-    with self.test_session():
+    with self.cached_session():
       images = array_ops.placeholder(np.float32,
                                      [None, None, None, input_size[3]])
       output = layers_lib.convolution2d(
@@ -670,7 +670,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.convolution2d(
         images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'conv7/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -688,7 +688,7 @@
         padding='VALID',
         activation_fn=None,
         scope='conv7')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'conv7/BiasAdd')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -712,7 +712,7 @@
 
   def testInvalidDataFormat(self):
     height, width = 7, 9
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       with self.assertRaisesRegexp(
           ValueError, 'data_format has to be either NCHW or NHWC.'):
@@ -915,7 +915,7 @@
         images, num_filters, [3, 3], stride=1, padding='SAME')
     self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertListEqual(list(output.eval().shape), expected_size)
 
@@ -929,7 +929,7 @@
         images, num_filters, [3, 3], stride=1, padding='VALID')
     self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertListEqual(list(output.eval().shape), expected_size)
 
@@ -944,7 +944,7 @@
     self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
     self.assertListEqual(list(output.get_shape().as_list()), expected_size)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertListEqual(list(output.eval().shape), expected_size)
 
@@ -958,7 +958,7 @@
         images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
     self.assertListEqual(list(output.get_shape().as_list()), expected_size)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -971,7 +971,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -984,7 +984,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -997,7 +997,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1010,7 +1010,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 4], stride=[2, 1], padding='VALID')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1023,7 +1023,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 4], stride=[2, 4], padding='VALID')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1036,7 +1036,7 @@
     images = random_ops.random_uniform(input_size, seed=1)
     output = layers_lib.conv2d_transpose(
         images, num_filters, [2, 4], stride=[2, 5], padding='VALID')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1083,7 +1083,7 @@
         images, num_filters, [3, 3], stride=[2, 2], padding='VALID')
     self.assertListEqual(output.get_shape().as_list(), expected_size)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
       eval_output = output.eval({images: np.zeros(input_size, np.float32)})
@@ -1095,7 +1095,7 @@
     expected_size = [None, None, None, num_filters]
     expected_size_dynamic = [5, 18, 22, num_filters]
 
-    with self.test_session():
+    with self.cached_session():
       images = array_ops.placeholder(np.float32,
                                      [None, None, None, input_size[3]])
       output = layers_lib.conv2d_transpose(
@@ -1116,7 +1116,7 @@
         images, num_filters, [3, 3], stride=2, padding='VALID', scope='conv7')
     self.assertEqual(output.op.name, 'conv7/Relu')
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertListEqual(list(output.eval().shape), expected_size)
 
@@ -1135,7 +1135,7 @@
         scope='conv7')
     self.assertEqual(output.op.name, 'conv7/BiasAdd')
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertListEqual(list(output.eval().shape), expected_size)
 
@@ -1146,7 +1146,7 @@
     stride = 2
     padding = 'VALID'
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(input_size, seed=1)
       output_deconv = layers_lib.conv2d_transpose(
           images,
@@ -1184,7 +1184,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(horz_gradients)
       expected = np.zeros((1, 10, 9, 1))
@@ -1201,7 +1201,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(
           horz_gradients, feed_dict={
@@ -1225,7 +1225,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(horz_gradients)
 
@@ -1245,7 +1245,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(horz_gradients)
 
@@ -1267,7 +1267,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(horz_gradients)
 
@@ -1283,12 +1283,12 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(vert_gradients)
       expected = np.zeros((1, 9, 10, 1))
 
-      self.assertAllEqual(result, expected)
+      self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
 
   def testVertConvWithVaryingImage(self):
     image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9'))
@@ -1306,7 +1306,7 @@
         activation_fn=None)
     init_op = variables_lib.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       result = sess.run(vert_gradients)
 
@@ -1314,7 +1314,7 @@
 
   def testConv1dShape(self):
     width = 7
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, width, 3), seed=1)
       output = layers_lib.convolution1d(images, 32, 3)
       self.assertEqual(output.op.name, 'Conv/Relu')
@@ -1322,7 +1322,7 @@
 
   def testConvInferSpatialDims(self):
     depth, height, width = 7, 9, 11
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
       output = layers_lib.convolution(images, 32, [3])
       self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
@@ -1344,7 +1344,7 @@
     sparse = _layers.dense_to_sparse(tensor)
     dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
                                        sparse.values)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       constant = sess.run(dense)
       self.assertAllEqual(expected_constant, constant)
 
@@ -1353,7 +1353,7 @@
 
   def testCreateDropout(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = _layers.dropout(images)
       self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
@@ -1362,7 +1362,7 @@
 
   def testCreateDropoutWithConstantTrue(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       is_training = constant_op.constant(True)
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = _layers.dropout(images, is_training=is_training)
@@ -1370,7 +1370,7 @@
 
   def testCreateDropoutWithConstantFalse(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       is_training = constant_op.constant(False)
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = _layers.dropout(images, is_training=is_training)
@@ -1378,7 +1378,7 @@
 
   def testCreateDropoutWithPlaceholder(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       is_training = array_ops.placeholder(dtype=dtypes.bool, shape=[])
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = _layers.dropout(images, is_training=is_training)
@@ -1387,7 +1387,7 @@
 
   def testCollectOutputs(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = _layers.dropout(images, outputs_collections='outputs')
       c_output = ops.get_collection('outputs')[0]
@@ -1396,7 +1396,7 @@
 
   def testDropout(self):
     height, width = 10, 10
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1409,7 +1409,7 @@
   def testDropoutSeed(self):
     """Test that providing the same seed produces the same result."""
     height, width = 10, 10
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output1 = _layers.dropout(images, seed=1)
@@ -1418,7 +1418,7 @@
 
   def testCreateDropoutNoTraining(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1431,7 +1431,7 @@
 
   def testCreateFCFollowByDropout(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.fully_connected(images, 50)
@@ -1445,7 +1445,7 @@
 
   def testCreateFCWithDropout(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.fully_connected(
@@ -1475,7 +1475,7 @@
 
   def testCollectOutputs(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = _layers.flatten(images, outputs_collections='outputs')
       c_output = ops.get_collection('outputs')[0]
@@ -1484,7 +1484,7 @@
 
   def testFlatten4D(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.flatten(images)
@@ -1494,7 +1494,7 @@
 
   def testFlatten3D(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width), seed=1, name='images')
       output = _layers.flatten(images)
@@ -1504,7 +1504,7 @@
 
   def testFlattenBatchSize(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       inputs = array_ops.placeholder(dtypes.int32, (None, height, width, 3))
@@ -1516,7 +1516,7 @@
 
   def testUnknownDims(self):
     height = width = depth = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform(
           (5, height, width, depth), seed=1, name='images')
       inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None))
@@ -1551,7 +1551,7 @@
       flattened_t = _layers._inner_flatten(inputs, new_rank)
       static_shape = flattened_t.get_shape().as_list()
       self.assertEqual(static_shape, expected_new_shape)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         flattened = sess.run(flattened_t)
       np.testing.assert_array_equal(expected_flattened, flattened)
 
@@ -1571,7 +1571,7 @@
 
       flattened_t = _layers._inner_flatten(inputs_t, new_rank)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         flattened = sess.run(flattened_t)
 
       np.testing.assert_array_equal(expected_indices, flattened.indices)
@@ -1641,7 +1641,7 @@
 
   def testCreateFCWithScope(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       output = _layers.fully_connected(inputs, 32, scope='fc1')
       self.assertEqual(output.op.name, 'fc1/Relu')
@@ -1659,7 +1659,7 @@
   def testCreateFcCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('fc1/weights'))
       self.assertFalse(variables.get_variables('fc1/biases'))
       _layers.fully_connected(inputs, 32, scope='fc1')
@@ -1669,7 +1669,7 @@
   def testReuseVars(self):
     height, width = 3, 3
     inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       _layers.fully_connected(inputs, 32, scope='fc1')
       self.assertEqual(len(variables.get_variables('fc1')), 2)
       _layers.fully_connected(inputs, 32, scope='fc1', reuse=True)
@@ -1678,7 +1678,7 @@
   def testNonReuseVars(self):
     height, width = 3, 3
     inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       _layers.fully_connected(inputs, 32)
       self.assertEqual(len(variables.get_variables('fully_connected')), 2)
       _layers.fully_connected(inputs, 32)
@@ -1713,14 +1713,14 @@
 
   def testCreateFCWithoutActivation(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       output = _layers.fully_connected(inputs, 32, activation_fn=None)
       self.assertEqual(output.op.name, 'fully_connected/BiasAdd')
 
   def testCreateFCWithWD(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       weight_decay = regularizers.l2_regularizer(0.01)
       _layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
@@ -1732,7 +1732,7 @@
 
   def testCreateFCWithBD(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       bias_decay = regularizers.l2_regularizer(0.01)
       _layers.fully_connected(inputs, 32, biases_regularizer=bias_decay)
@@ -1744,7 +1744,7 @@
 
   def testCreateNoRegularizers(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       _layers.fully_connected(inputs, 32)
       self.assertEqual(
@@ -1752,7 +1752,7 @@
 
   def testReuseFCWithWD(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
       weight_decay = regularizers.l2_regularizer(0.01)
       _layers.fully_connected(
@@ -1768,7 +1768,7 @@
 
   def testFCWithBatchNorm(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height * width * 3), seed=1)
       with arg_scope(
           [_layers.fully_connected],
@@ -1786,7 +1786,7 @@
 
   def testReuseFCWithBatchNorm(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height * width * 3), seed=1)
       with arg_scope(
           [_layers.fully_connected],
@@ -1844,7 +1844,7 @@
     if dtype is None:
       dtype = dtypes.float32
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3)).astype(
           dtype.as_numpy_dtype)
       output = _layers.batch_norm(images, fused=fused)
@@ -1866,7 +1866,7 @@
 
   def _testCreateOpBetaRegularizer(self, fused=True):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       reg = lambda x: 0.1 * math_ops.reduce_sum(x)
       images = np.random.uniform(size=(5, height, width, 3)).astype('f')
       _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
@@ -1883,7 +1883,7 @@
 
   def _testCreateOpGammaRegularizer(self, fused=True):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       reg = lambda x: 0.1 * math_ops.reduce_sum(x)
       images = np.random.uniform(size=(5, height, width, 3)).astype('f')
       _layers.batch_norm(
@@ -1901,7 +1901,7 @@
 
   def testCreateVariables(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.batch_norm(images, scale=True)
       beta = variables.get_variables_by_name('beta')[0]
@@ -1915,7 +1915,7 @@
 
   def testMovingAverageVariables(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.batch_norm(images, scale=True)
       self.assertEqual(len(variables.get_model_variables()), 4)
@@ -1926,7 +1926,7 @@
 
   def testMovingAverageVariablesZeroDebias(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.batch_norm(
           images, scale=True, zero_debias_moving_mean=True, fused=False)
@@ -1943,7 +1943,7 @@
 
   def testUpdatesCollection(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.batch_norm(images, updates_collections='my_update_ops')
       update_layers = ops.get_collection('my_update_ops')
@@ -1971,7 +1971,7 @@
 
   def testReuseVariables(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.batch_norm(images, scale=True, scope='bn')
       _layers.batch_norm(images, scale=True, scope='bn', reuse=True)
@@ -1986,7 +1986,7 @@
 
   def testReuseUpdateOps(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       with arg_scope([_layers.batch_norm], updates_collections='update_ops'):
         _layers.batch_norm(images, scope='bn')
@@ -1996,7 +1996,7 @@
 
   def testCreateMovingVars(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _ = _layers.batch_norm(images)
       moving_mean = variables.get_variables('BatchNorm/moving_mean')
@@ -2029,7 +2029,7 @@
     moving_variance = variables.get_variables_by_name('moving_variance')[0]
     biased = variables.get_variables_by_name('biased')[0]
     local_step = variables.get_variables_by_name('local_step')[0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       self.assertAllClose(local_step.eval(), 0)
       self.assertAllClose(moving_mean.eval(), [0] * channels)
@@ -2213,7 +2213,7 @@
 
   def _testEvalMovingVars(self, zero_debias_moving_mean=False):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       image_shape = (10, height, width, 3)
       image_values = np.random.rand(*image_shape)
       expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2264,7 +2264,7 @@
     height, width = 3, 3
     batch_size = 10
     channels = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       image_shape = (batch_size, height, width, channels)
       image_values = np.random.rand(*image_shape)
       expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2435,7 +2435,7 @@
 
   def testNoUpdatesWhenIsTrainingFalse(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       image_shape = (10, height, width, 3)
       image_values = np.random.rand(*image_shape)
       images = constant_op.constant(
@@ -2460,7 +2460,7 @@
 
   def testNoneUpdatesCollectionNoTraining(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       image_shape = (10, height, width, 3)
       image_values = np.random.rand(*image_shape)
       images = constant_op.constant(
@@ -2647,7 +2647,7 @@
   def testCustomInitializer(self):
     height, width = 3, 3
     channels = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = (np.ones((5, height, width, channels)) * 9.0).astype('f')
       beta = init_ops.constant_initializer(
           (np.ones(channels) * 5.0).astype('f'))
@@ -2728,7 +2728,7 @@
 
   def testBatchNormBeta(self):
     # Test case for 11673
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
       _layers.batch_norm(
           a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
@@ -2739,7 +2739,7 @@
 
   def testVariablesAreFloat32(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, dtype=dtypes.float16)
       _layers.batch_norm(images, scale=True)
@@ -2824,7 +2824,7 @@
 
   def testCreateOp(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = _layers.layer_norm(images)
       self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm'))
@@ -2832,7 +2832,7 @@
 
   def testCreateVariables(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.layer_norm(images)
       beta = variables.get_variables_by_name('beta')[0]
@@ -2842,7 +2842,7 @@
 
   def testReuseVariables(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       _layers.layer_norm(images, scope='ln')
       _layers.layer_norm(images, scope='ln', reuse=True)
@@ -2853,7 +2853,7 @@
 
   def testReuseVars(self):
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       image_shape = (10, height, width, 3)
       image_values = np.random.rand(*image_shape)
       images = constant_op.constant(
@@ -2940,7 +2940,7 @@
   def _runGDN(self, x, shape, inverse, data_format):
     inputs = array_ops.placeholder(dtypes.float32, shape)
     outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       y, = sess.run([outputs], {inputs: x})
     return y
@@ -3152,14 +3152,14 @@
 class OneHotEncodingTest(test.TestCase):
 
   def testOneHotEncodingCreate(self):
-    with self.test_session():
+    with self.cached_session():
       labels = np.array([0, 1, 2])
       output = _layers.one_hot_encoding(labels, num_classes=3)
       self.assertEqual(output.op.name, 'OneHotEncoding/one_hot')
       self.assertListEqual(output.get_shape().as_list(), [3, 3])
 
   def testCollectOutputs(self):
-    with self.test_session():
+    with self.cached_session():
       labels = constant_op.constant([0, 1, 2])
       output = _layers.one_hot_encoding(
           labels, num_classes=3, outputs_collections='outputs')
@@ -3168,14 +3168,14 @@
       self.assertEqual(c_output, output)
 
   def testOneHotEncoding(self):
-    with self.test_session():
+    with self.cached_session():
       labels = constant_op.constant([0, 1, 2])
       one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
       output = _layers.one_hot_encoding(labels, num_classes=3)
       self.assertAllClose(output.eval(), one_hot_labels.eval())
 
   def testOneHotEncodingInt32(self):
-    with self.test_session():
+    with self.cached_session():
       labels = constant_op.constant([0, 1, 2], dtype=dtypes.int32)
       one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
       output = _layers.one_hot_encoding(labels, num_classes=3)
@@ -3186,7 +3186,7 @@
 
   def testRepeat(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
       output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3])
       self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu')
@@ -3194,7 +3194,7 @@
 
   def testRepeatWithScope(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.repeat(
@@ -3207,7 +3207,7 @@
 
   def testCreateConvInt32(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, dtype=dtypes.int32, maxval=12345)
       with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
@@ -3215,7 +3215,7 @@
 
   def testCreateConvFloat32(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, dtype=dtypes.float32)
       output = layers_lib.separable_conv2d(images, 32, [3, 3], 2)
@@ -3224,7 +3224,7 @@
 
   def testCreateDepthwiseConv(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(images, None, [3, 3], 2)
       self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
@@ -3233,7 +3233,7 @@
   def testCreateConvCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
       self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
       self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3245,7 +3245,7 @@
   def testCreateAtrousConvCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
       self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
       self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3257,7 +3257,7 @@
   def testCreateDepthwiseConvCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     images = random_ops.random_uniform((5, height, width, 3), seed=1)
-    with self.test_session():
+    with self.cached_session():
       self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
       self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
       self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3268,14 +3268,14 @@
 
   def testCreateConvWithScope(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(images, 32, [3, 3], 6, scope='conv1')
       self.assertEqual(output.op.name, 'conv1/Relu')
 
   def testCreateConvWithoutActivation(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(
           images, 32, [3, 3], 8, activation_fn=None)
@@ -3283,7 +3283,7 @@
 
   def testCreateConvValid(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(
           images, 32, [3, 3], 2, padding='VALID')
@@ -3291,7 +3291,7 @@
 
   def testCreateAtrousConvValid(self):
     height, width = 5, 5
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(
           images, 32, [3, 3], 2, padding='VALID', rate=2)
@@ -3299,7 +3299,7 @@
 
   def testCreateDepthwiseConvValid(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(
           images, None, [3, 3], 2, padding='VALID')
@@ -3307,7 +3307,7 @@
 
   def testCreateAtrousDepthwiseConvValid(self):
     height, width = 5, 5
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       output = layers_lib.separable_conv2d(
           images, None, [3, 3], 2, padding='VALID', rate=2)
@@ -3316,7 +3316,7 @@
   def testCreateConvWithWeightDecay(self):
     random_seed.set_random_seed(0)
     height, width = 3, 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       regularizer = regularizers.l2_regularizer(0.01)
       layers_lib.separable_conv2d(
@@ -3360,7 +3360,7 @@
 
   def testReuseConvWithWeightDecay(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       regularizer = regularizers.l2_regularizer(0.01)
       layers_lib.separable_conv2d(
@@ -3419,7 +3419,7 @@
         normalizer_params={},
         scope='conv1')
     init_op = variables_lib.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       images = np.random.rand(5, height, width, 3)
       sess.run(init_op)
       sess.run(net, feed_dict={images_placeholder: images})
@@ -3440,7 +3440,7 @@
 
   def testSepConvNCHW(self):
     for num_filters, correct_output_filters in zip((None, 5), (6, 5)):
-      with self.test_session():
+      with self.cached_session():
         batch, height, width = 4, 10, 12
         kernel_dim, stride = 3, 2
         images = random_ops.random_uniform((batch, 3, height, width), seed=1)
@@ -3462,7 +3462,7 @@
   """Simple tests of the scale_gradient function."""
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.array([42], np.float32)
       gradient_scale = np.array([2], np.float32)
 
@@ -3513,7 +3513,7 @@
     exp_prediction = np.array([[self.low, self.high], [0.5, 0.5],
                                [self.high, self.low]])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       prediction = sess.run(prediction)
       self.assertAllClose(exp_prediction, prediction)
 
@@ -3529,7 +3529,7 @@
     exp_prediction[1, 1, 1] = self.low
 
     prediction = _layers.softmax(logits)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       prediction = sess.run(prediction)
       self.assertAllClose(exp_prediction, prediction)
 
@@ -3547,7 +3547,7 @@
     exp_prediction[1, 1, 1] = self.low
 
     prediction = _layers.softmax(logit_placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       prediction = sess.run(prediction, feed_dict=feed_dict)
       self.assertAllClose(exp_prediction, prediction)
 
@@ -3575,7 +3575,7 @@
     features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
     np_features = np.zeros(batch_shape, dtype=np.float32)
     spatial_softmax = _layers.spatial_softmax(features)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3586,7 +3586,7 @@
     features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
     np_features = np.zeros(batch_shape, dtype=np.float32)
     spatial_softmax = _layers.spatial_softmax(features, data_format='NCHW')
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3613,7 +3613,7 @@
                                         nchannels)
 
     # Make sure expected location keypoints matches actual location keypoints.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3637,7 +3637,7 @@
                                         nchannels)
 
     # Make sure expected location keypoints matches actual location keypoints.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3669,7 +3669,7 @@
                                          batch_size, nchannels)
 
     # Make sure expected location keypoints matches actual location keypoints.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features1}
       tf_keypoints1 = sess.run(spatial_softmax, feed_dict)
@@ -3696,7 +3696,7 @@
                                         nchannels)
 
     # Make sure expected location keypoints matches actual location keypoints.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3719,7 +3719,7 @@
                                         nchannels)
 
     # Make sure expected location keypoints matches actual location keypoints.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3731,7 +3731,7 @@
     spatial_softmax = _layers.spatial_softmax(features)
     net = _layers.fully_connected(spatial_softmax, 10)
     np_features = np.zeros(batch_shape, dtype=np.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       feed_dict = {features: np_features}
       sess.run(net, feed_dict)
@@ -3741,7 +3741,7 @@
 
   def testStackFullyConnected(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = np.random.uniform(size=(5, height * width * 3))
       output = _layers.stack(images, _layers.fully_connected, [10, 20, 30])
       self.assertEqual(output.op.name, 'Stack/fully_connected_3/Relu')
@@ -3749,7 +3749,7 @@
 
   def testStackFullyConnectedFailOnReuse(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope('test', reuse=True):
         images = np.random.uniform(size=(5, height * width * 3))
         with self.assertRaises(ValueError):
@@ -3757,7 +3757,7 @@
 
   def testStackRelu(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height * width * 3), seed=1, name='images')
       output = _layers.stack(images, layers_lib.relu, [10, 20, 30])
@@ -3766,7 +3766,7 @@
 
   def testStackElu(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height * width * 3), seed=1, name='images')
       output = _layers.stack(images, layers_lib.elu, [10, 20, 30])
@@ -3775,7 +3775,7 @@
 
   def testStackConvolution2d(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.stack(
@@ -3788,7 +3788,7 @@
 
   def testStackWithScope(self):
     height, width = 3, 3
-    with self.test_session():
+    with self.cached_session():
       images = random_ops.random_uniform(
           (5, height, width, 3), seed=1, name='images')
       output = _layers.stack(
@@ -3817,7 +3817,7 @@
       del shape[dim]
       expected = np.ones(shape)
 
-      with self.test_session():
+      with self.cached_session():
         actual = norms.eval()
         self.assertAllClose(expected, actual, 1e-4, 1e-4)
 
@@ -3849,7 +3849,7 @@
       norms = math_ops.sqrt(
           math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
 
-      with self.test_session():
+      with self.cached_session():
         actual = norms.eval({image: placeholder_value})
         self.assertAllClose(expected, actual, 1e-4, 1e-4)
 
@@ -3875,7 +3875,7 @@
     x_np = np.random.random_sample(x_shape).astype(np.float32)
     for dim in range(len(x_shape)):
       y_np = self._PoincareNormalize(x_np, dim, epsilon)
-      with self.test_session():
+      with self.cached_session():
         x_tf = constant_op.constant(x_np, name='x')
         y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
         y_tf_eval = y_tf.eval()
@@ -3893,7 +3893,7 @@
     x_np = np.random.random_sample(x_shape).astype(np.float32)
     dim = [1, 2]
     y_np = self._PoincareNormalize(x_np, dim, epsilon)
-    with self.test_session():
+    with self.cached_session():
       x_tf = constant_op.constant(x_np, name='x')
       y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
       y_tf_eval = y_tf.eval()
@@ -3908,7 +3908,7 @@
     np.random.seed(1)
     x_np = np.random.random_sample(x_shape).astype(np.float64)
     for dim in range(len(x_shape)):
-      with self.test_session():
+      with self.cached_session():
         x_tf = constant_op.constant(x_np, name='x')
         y_tf = _layers.poincare_normalize(x_tf, dim)
         err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -4117,7 +4117,7 @@
     # Empty x is common if someone masks their input with tf.boolean_mask in
     # order to drop missing entries, and in a particular batch all entries are
     # missing.
-    with self.test_session():
+    with self.cached_session():
       x = np.array([]).reshape(0, 3)
       self.assertEqual(0, array_ops.size(x).eval())
       y = _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax)
@@ -4131,7 +4131,7 @@
     y = _layers.legacy_fully_connected(x, 1)
     # in the output we still only know the 2nd and 3rd dimensions statically.
     self.assertEqual(y.get_shape().as_list(), [None, 4, 1])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       # we can feed in input with first dimension 2
       shape_value = sess.run(
@@ -4162,7 +4162,7 @@
       self._unknown_dim_invalid_input(last_dim=None)
 
   def test_1d_invalid_input(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError,
                                    'rank of x must be at least 2 not: 1'):
         x = constant_op.constant([[]], shape=[0])
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index 55272e5..c8d3c91 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -106,7 +106,7 @@
     images = random_ops.random_uniform(image_shape, seed=1)
     output_train = normalization.instance_norm(images, scope='IN')
     output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       # output_train and output_eval should be the same.
       train_np, eval_np = sess.run([output_train, output_eval])
@@ -130,7 +130,7 @@
         inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
         output_op = normalization.instance_norm(
             inputs, center=False, scale=False, data_format=data_format)
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(variables.global_variables_initializer())
           outputs = sess.run(output_op)
           # Make sure that there are no NaNs
@@ -287,7 +287,7 @@
     output_train = normalization.group_norm(images, groups=2, scope='IN')
     output_eval = normalization.group_norm(images, groups=2, scope='IN',
                                            reuse=True)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       # output_train and output_eval should be the same.
       train_np, eval_np = sess.run([output_train, output_eval])
@@ -349,7 +349,7 @@
             channels_axis=channels_axis,
             reduction_axes=reduction_axes,
             mean_close_to_zero=mean_close_to_zero)
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(variables.global_variables_initializer())
           outputs = sess.run(output_op)
           # Make sure that there are no NaNs
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 0f037e2..29dede2 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -165,7 +165,7 @@
 
   def testGradientNoise(self):
     random_seed.set_random_seed(42)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       train = optimizers_lib.optimize_loss(
           loss,
@@ -182,7 +182,7 @@
 
   def testGradientNoiseWithClipping(self):
     random_seed.set_random_seed(42)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       train = optimizers_lib.optimize_loss(
           loss,
@@ -198,7 +198,7 @@
       self.assertEqual(global_step_value, 1)
 
   def testGradientClip(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       train = optimizers_lib.optimize_loss(
           loss,
@@ -213,7 +213,7 @@
       self.assertEqual(global_step_value, 1)
 
   def testAdaptiveGradientClip(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       clip_gradients = optimizers_lib.adaptive_clipping_fn()
       train = optimizers_lib.optimize_loss(
@@ -234,7 +234,7 @@
       self.assertEqual(2, var_count)
 
   def testGradientMultiply(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, var, loss, global_step = _setup_model()
       train = optimizers_lib.optimize_loss(
           loss,
@@ -433,7 +433,7 @@
 class AdaptiveClipping(test.TestCase):
 
   def testAverages(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       scale = 2.
       grad = array_ops.ones([3, 4]) * scale
       log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
@@ -463,7 +463,7 @@
       self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)
 
   def testClip(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       spike = 1000.
       multiplier = array_ops.placeholder(dtypes.float32, [], "multiplier")
       step = array_ops.placeholder(dtypes.int32, [], "step")
diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py
index 07191ee..51faba3 100644
--- a/tensorflow/contrib/layers/python/layers/regularizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py
@@ -71,7 +71,7 @@
     with self.assertRaises(ValueError):
       regularizers.l1_l2_regularizer(0.5, 0)
 
-    with self.test_session():
+    with self.cached_session():
       shape = [5, 5, 5]
       num_elem = 5 * 5 * 5
       tensor = constant_op.constant(1.0, shape=shape)
@@ -84,7 +84,7 @@
     num_elem = 5 * 5 * 5
     tensor = constant_op.constant(1.0, shape=shape)
     loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
-    with self.test_session():
+    with self.cached_session():
       self.assertEquals(loss.op.name, 'l1_l2_regularizer')
       self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
 
@@ -93,7 +93,7 @@
     num_elem = 5 * 5 * 5
     tensor = constant_op.constant(1.0, shape=shape)
     loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
-    with self.test_session():
+    with self.cached_session():
       self.assertEquals(loss.op.name, 'l1_l2_regularizer')
       self.assertAlmostEqual(loss.eval(), num_elem, 5)
 
@@ -104,7 +104,7 @@
     self.assertEquals(loss, None)
 
   def testL1L2RegularizerWithScope(self):
-    with self.test_session():
+    with self.cached_session():
       shape = [5, 5, 5]
       num_elem = 5 * 5 * 5
       tensor = constant_op.constant(1.0, shape=shape)
@@ -142,7 +142,7 @@
     array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
     tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
     expected = sum([2 * x for l in array_weights_list for x in l])
-    with self.test_session():
+    with self.cached_session():
       result = regularizers.apply_regularization(dummy_regularizer,
                                                  tensor_weights_list)
       self.assertAllClose(expected, result.eval())
@@ -151,7 +151,7 @@
     regularizer = regularizers.l2_regularizer(0.0)
     array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
     tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
-    with self.test_session():
+    with self.cached_session():
       result = regularizers.apply_regularization(regularizer,
                                                  tensor_weights_list)
       self.assertAllClose(0.0, result.eval())
@@ -161,7 +161,7 @@
     tensor_weights_list = [
         constant_op.constant(x) for x in [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
     ]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         regularizers.apply_regularization(non_scalar_regularizer,
                                           tensor_weights_list)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index c34b5a8..2c7463a 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -58,7 +58,7 @@
     y1, y2 = block.forward(x1, x2)
     x1_inv, x2_inv = block.backward(y1, y2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv])
 
@@ -81,7 +81,7 @@
     x1, x2 = block.backward(y1, y2)
     y1_inv, y2_inv = block.forward(x1, x2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
 
@@ -151,7 +151,7 @@
     grads_rev = gradients_impl.gradients(loss_rev, wrt)
     grads = gradients_impl.gradients(loss, wrt)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
       self.assertAllClose(y_val, yd_val)
@@ -286,7 +286,7 @@
     for out, scope_vars in outputs_and_vars:
       all_grads.append(gradients_impl.gradients(out, scope_vars))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       outputs = list(zip(*outputs_and_vars))[0]
       outs, all_grads_val = sess.run([outputs, all_grads])
@@ -389,7 +389,7 @@
       layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
 
     grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(grads)
 
   def testErrorOnClosedOverTensor(self):
diff --git a/tensorflow/contrib/layers/python/layers/summaries_test.py b/tensorflow/contrib/layers/python/layers/summaries_test.py
index a1ef06f..2ec2af9 100644
--- a/tensorflow/contrib/layers/python/layers/summaries_test.py
+++ b/tensorflow/contrib/layers/python/layers/summaries_test.py
@@ -29,19 +29,19 @@
 class SummariesTest(test.TestCase):
 
   def test_summarize_scalar_tensor(self):
-    with self.test_session():
+    with self.cached_session():
       scalar_var = variables.Variable(1)
       summary_op = summaries_lib.summarize_tensor(scalar_var)
       self.assertEquals(summary_op.op.type, 'ScalarSummary')
 
   def test_summarize_multidim_tensor(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_var = variables.Variable([1, 2, 3])
       summary_op = summaries_lib.summarize_tensor(tensor_var)
       self.assertEquals(summary_op.op.type, 'HistogramSummary')
 
   def test_summarize_activation(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(1)
       op = array_ops.identity(var, name='SummaryTest')
       summary_op = summaries_lib.summarize_activation(op)
@@ -52,7 +52,7 @@
       self.assertIn(u'SummaryTest/activation', names)
 
   def test_summarize_activation_relu(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(1)
       op = nn_ops.relu(var, name='SummaryTest')
       summary_op = summaries_lib.summarize_activation(op)
@@ -64,7 +64,7 @@
       self.assertIn(u'SummaryTest/activation', names)
 
   def test_summarize_activation_relu6(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(1)
       op = nn_ops.relu6(var, name='SummaryTest')
       summary_op = summaries_lib.summarize_activation(op)
@@ -77,7 +77,7 @@
       self.assertIn(u'SummaryTest/activation', names)
 
   def test_summarize_collection_regex(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(1)
       array_ops.identity(var, name='Test1')
       ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py
index a9bd895..34f63f5 100644
--- a/tensorflow/contrib/layers/python/layers/utils_test.py
+++ b/tensorflow/contrib/layers/python/layers/utils_test.py
@@ -42,7 +42,7 @@
       c = constant_op.constant(v)
       value = utils.constant_value(c)
       self.assertEqual(value, v)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(c.eval(), v)
 
   def test_variable(self):
@@ -60,7 +60,7 @@
       x = array_ops.identity(p)
       value = utils.constant_value(p)
       self.assertEqual(value, None)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(x.eval(feed_dict={p: v}), v)
 
 
@@ -80,7 +80,7 @@
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
       o = utils.static_cond(v, fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(), expected(v))
 
   def test_variable(self):
@@ -89,7 +89,7 @@
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
       o = utils.static_cond(v, fn1, fn2)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.global_variables_initializer())
         self.assertEqual(o.eval(), expected(v))
 
@@ -99,7 +99,7 @@
     expected = lambda v: -1 if v else -2
     for v in [True, False, 1, 0]:
       o = utils.static_cond(v, fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(), expected(v))
 
 
@@ -119,7 +119,7 @@
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(), expected(v))
 
   def test_variable(self):
@@ -128,7 +128,7 @@
     expected = lambda v: b'fn1' if v else b'fn2'
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.global_variables_initializer())
         self.assertEqual(o.eval(), expected(v))
 
@@ -138,7 +138,7 @@
     expected = lambda v: -1 if v else -2
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(), expected(v))
 
 
@@ -151,7 +151,7 @@
     p = array_ops.placeholder(dtypes.bool, [])
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(p, fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
 
   def test_constant(self):
@@ -161,7 +161,7 @@
     p = array_ops.placeholder(dtypes.bool, [])
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(p, fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
 
   def test_variable(self):
@@ -171,7 +171,7 @@
     p = array_ops.placeholder(dtypes.bool, [])
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(p, fn1, fn2)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.global_variables_initializer())
         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
 
@@ -182,7 +182,7 @@
     p = array_ops.placeholder(dtypes.bool, [])
     for v in [True, False, 1, 0]:
       o = utils.smart_cond(p, fn1, fn2)
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
 
 
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index d507500..b6c2cab 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -42,7 +42,7 @@
 class DenseToSparseTensorTest(test.TestCase):
 
   def test_dense_to_sparse_tensor_1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([1, 0, 2, 0])
       result = sess.run(st)
     self.assertEqual(result.indices.dtype, np.int64)
@@ -53,7 +53,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_1d_float(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([1.5, 0.0, 2.3, 0.0])
       result = sess.run(st)
     self.assertEqual(result.indices.dtype, np.int64)
@@ -64,7 +64,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_1d_bool(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([True, False, True, False])
       result = sess.run(st)
     self.assertEqual(result.indices.dtype, np.int64)
@@ -75,7 +75,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_1d_str(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b''])
       result = sess.run(st)
     self.assertEqual(result.indices.dtype, np.int64)
@@ -86,7 +86,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor(
           [b'qwe', b'', b'ewq', b''], ignore_value=b'qwe')
       result = sess.run(st)
@@ -98,7 +98,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([[1, 2, 0, 0], [3, 4, 5, 0]])
       result = sess.run(st)
     self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -107,7 +107,7 @@
     self.assertAllEqual([2, 4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_3d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor([[[1, 2, 0, 0], [3, 4, 5, 0]],
                                               [[7, 8, 0, 0], [9, 0, 0, 0]]])
       result = sess.run(st)
@@ -117,7 +117,7 @@
     self.assertAllEqual([2, 2, 4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_unknown_1d_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32)
       st = sparse_ops.dense_to_sparse_tensor(tensor)
       result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]})
@@ -126,7 +126,7 @@
     self.assertAllEqual([4], result.dense_shape)
 
   def test_dense_to_sparse_tensor_unknown_3d_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tensor = array_ops.placeholder(
           shape=[None, None, None], dtype=dtypes.int32)
       st = sparse_ops.dense_to_sparse_tensor(tensor)
@@ -142,7 +142,7 @@
 
   def test_dense_to_sparse_unknown_rank(self):
     ph = array_ops.placeholder(dtype=dtypes.int32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       st = sparse_ops.dense_to_sparse_tensor(ph)
       result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
     self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -155,7 +155,7 @@
 
   def test_sparse_row_envelope(self):
     expected_sparse_row_envelope = [1, 0, 3]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_input = sparse_tensor.SparseTensor(
           indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
           values=[0, 1, 2, 3],
@@ -167,7 +167,7 @@
 
   def test_sparse_row_envelope_unsorted_indices(self):
     expected_sparse_row_envelope = [1, 0, 3]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_input = sparse_tensor.SparseTensor(
           indices=[[2, 0], [2, 2], [2, 1], [0, 0]],
           values=[0, 1, 2, 3],
@@ -179,7 +179,7 @@
 
   def test_sparse_row_envelope_empty_in_the_end(self):
     expected_sparse_row_envelope = [1, 0, 3, 0, 0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_input = sparse_tensor.SparseTensor(
           indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
           values=[0, 1, 2, 3],
@@ -191,7 +191,7 @@
 
   def test_sparse_row_envelope_empty_3d(self):
     expected_sparse_row_envelope = [1, 0, 3, 0, 0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sparse_input = sparse_tensor.SparseTensor(
           indices=[[0, 0, 0], [0, 2, 0], [0, 2, 1], [0, 2, 2]],
           values=[0, 1, 2, 3],
@@ -207,7 +207,7 @@
   def test_indicators_to_sparse_ids_1d(self):
     indicators = (0, 0, 1, 0)
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0,),),
           values=(2,),
@@ -220,7 +220,7 @@
         (1, 0, 0, 1),
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0), (1, 0), (1, 1)),
           values=(2, 0, 3),
@@ -235,7 +235,7 @@
         ((1, 0, 0, 1, 1), (0, 0, 1, 0, 0)),
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=(
               (0, 0, 0),
@@ -255,7 +255,7 @@
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(
         indicators, dtype=dtypes.int16)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0), (1, 0), (1, 1)),
           values=np.array((2, 0, 3), dtype=np.int16),
@@ -269,7 +269,7 @@
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(
         indicators, ignore_value=-1)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
           values=(2, 0, 3, 2),
@@ -282,7 +282,7 @@
         (('B', '', '', 'C'), ('', '', 'D', '')),
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
           values=(2, 0, 3, 2),
@@ -296,7 +296,7 @@
     )
     sparse_ids = sparse_ops.indicators_to_sparse_ids(
         indicators, ignore_value='x')
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
           values=(2, 0, 3, 2),
@@ -311,7 +311,7 @@
     indicators = array_ops.placeholder(
         dtype=dtypes.int32, shape=(None, None, None))
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
           values=(2, 0, 3, 2),
@@ -325,7 +325,7 @@
     )
     indicators = array_ops.placeholder(dtype=dtypes.int32)
     sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
-    with self.test_session():
+    with self.cached_session():
       _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
           indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
           values=(2, 0, 3, 2),
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 418b0cf..61185f6 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -403,6 +403,7 @@
     srcs = ["python/learn/estimators/dnn_test.py"],
     shard_count = 4,
     srcs_version = "PY2AND3",
+    tags = ["notap"],
     deps = [
         ":learn",
         "//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 5e07b93..284a4f4 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -147,7 +147,7 @@
   def test_unsupervised(self):
 
     def func(feeder):
-      with self.test_session():
+      with self.cached_session():
         inp, _ = feeder.input_builder()
         feed_dict_fn = feeder.get_feed_dict_fn()
         feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@
   def test_epoch(self):
 
     def func(feeder):
-      with self.test_session():
+      with self.cached_session():
         feeder.input_builder()
         epoch = feeder.make_epoch_variable()
         feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b..5e90d1f 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@
             'label': np.ones(1) * index - 32
         }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator,
           target_key='label',
@@ -68,7 +68,7 @@
       for index in range(2):
         yield {'a': np.ones(1) * index}
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
       features = input_fn()
@@ -97,7 +97,7 @@
             'label2': np.ones(1) * index - 64,
         }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator,
           target_key=['label', 'label2'],
@@ -134,7 +134,7 @@
             'label': np.ones((3, 3)) * index - 32
         }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator,
           target_key='label',
@@ -162,7 +162,7 @@
 
   def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
     x = np.arange(32, 36)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
         failing_input_fn = generator_io.generator_input_fn(
             x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@
     def generator():
       return np.arange(32, 36)
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
         failing_input_fn = generator_io.generator_input_fn(
             generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@
     def generator():
       yield np.arange(32, 36)
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
         failing_input_fn = generator_io.generator_input_fn(
             generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@
         }
 
     y = np.arange(32, 36)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
                                    ' Container of str'):
         failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@
         }
 
     y = ['label', np.arange(10)]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
                                    ' Container of str'):
         failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@
         }
 
     y = ['label', 'target']
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
         failing_input_fn = generator_io.generator_input_fn(
             generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@
             'label': np.ones(1) * index - 32
         }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
       features = input_fn()
@@ -283,7 +283,7 @@
             'label': np.ones(1) * index - 32
         }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
       features = input_fn()
@@ -319,7 +319,7 @@
           'label': np.ones(1) * index - 32
       }
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = generator_io.generator_input_fn(
           generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
       features = input_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e..396539a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@
   def testPandasInputFn_ProducesExpectedOutputs(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@
   def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       index = np.arange(100, 102)
       a = np.arange(2)
       b = np.arange(32, 34)
@@ -107,7 +107,7 @@
   def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       index = np.arange(100, 105)
       a = np.arange(5)
       b = np.arange(32, 37)
@@ -146,7 +146,7 @@
   def testPandasInputFn_OnlyX(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, _ = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@
   def testPandasInputFn_ExcludesIndex(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@
   def testPandasInputFn_RespectsEpoch_NoShuffle(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@
   def testPandasInputFn_RespectsEpoch_WithShuffle(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@
   def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@
     if not HAS_PANDAS:
       return
     x, y = self.makeTestDataFrame()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=3, shuffle=False, num_epochs=1)
 
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923..ff19011 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@
   """Ops tests."""
 
   def test_softmax_classifier(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       features = array_ops.placeholder(dtypes.float32, [None, 3])
       labels = array_ops.placeholder(dtypes.float32, [None, 2])
       weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@
     ids_shape = (2, 3, 4)
     embeds = np.random.randn(n_embed, d_embed)
     ids = np.random.randint(0, n_embed, ids_shape)
-    with self.test_session():
+    with self.cached_session():
       embed_np = embeds[ids]
       embed_tf = ops.embedding_lookup(embeds, ids).eval()
     self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@
 
   def test_categorical_variable(self):
     random_seed.set_random_seed(42)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
       embeddings = ops.categorical_variable(
           cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61..5a7e4eb 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@
   """Sequence-to-sequence tests."""
 
   def test_sequence_classifier(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       decoding = [
           array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
       ]
@@ -60,7 +60,7 @@
   def test_seq2seq_inputs(self):
     inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
     out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
       y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
       in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@
                                   [[0, 0, 0], [0, 0, 0]]])
 
   def test_rnn_decoder(self):
-    with self.test_session():
+    with self.cached_session():
       decoder_inputs = [
           array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
       ]
diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
index 423dcce..8390ddd 100644
--- a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
+++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
@@ -29,7 +29,7 @@
 class DecodeLibsvmOpTest(test.TestCase):
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       content = [
           "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503",
           "2 3:2.5 2:nan 1:0.105"
@@ -48,7 +48,7 @@
                      [0, 0.105, np.nan, 2.5, 0, 0]])
 
   def testNDimension(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"],
                  ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"],
                  ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]]
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index a2d82cf..553b116 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -30,7 +30,7 @@
 
   def testShardedMutableHashTable(self):
     for num_shards in [1, 3, 10]:
-      with self.test_session():
+      with self.cached_session():
         default_val = -1
         empty_key = 0
         keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -53,7 +53,7 @@
 
   def testShardedMutableHashTableVectors(self):
     for num_shards in [1, 3, 10]:
-      with self.test_session():
+      with self.cached_session():
         default_val = [-0.1, 0.2]
         empty_key = [0, 1]
         keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
@@ -79,7 +79,7 @@
                             output.eval())
 
   def testExportSharded(self):
-    with self.test_session():
+    with self.cached_session():
       empty_key = -2
       default_val = -1
       num_shards = 2
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
index 237a681..51c4f68 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -36,13 +36,13 @@
     self.assertTrue(isinstance(sfc.example_indices, ops.Tensor))
     self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor))
     self.assertEqual(sfc.feature_values, None)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
       self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
     expected_feature_values = [1.0, 2.0, 3.0, 4.0]
     sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
                               expected_feature_values)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
 
 
diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp
index a960d85..3e7999a 100644
--- a/tensorflow/contrib/lite/Android.bp
+++ b/tensorflow/contrib/lite/Android.bp
@@ -31,7 +31,7 @@
 cc_library_static {
     name: "libtflite_context",
     defaults: ["tflite_defaults"],
-    srcs: ["context.c"],
+    srcs: ["c/c_api_internal.c"],
     cflags: [
         "-Wno-typedef-redefinition",
         "-Wno-visibility",
@@ -45,15 +45,18 @@
     srcs: [
         "allocation.cc",
         "arena_planner.cc",
-        "error_reporter.cc",
+        "core/api/error_reporter.cc",
+        "core/api/flatbuffer_conversions.cc",
+        "core/api/op_resolver.cc",
         "graph_info.cc",
         "interpreter.cc",
         "mmap_allocation.cc",
         "model.cc",
-        "op_resolver.cc",
+        "mutable_op_resolver.cc",
         "nnapi_delegate.cc",
         "optional_debug_tools.cc",
         "simple_memory_arena.cc",
+        "stderr_reporter.cc",
         "string_util.cc",
         "util.cc",
         "kernels/eigen_support.cc",
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 0091587..f320b53 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -36,10 +36,10 @@
     srcs = ["arena_planner.cc"],
     hdrs = ["arena_planner.h"],
     deps = [
-        ":context",
         ":graph_info",
         ":memory_planner",
         ":simple_memory_arena",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ],
 )
 
@@ -54,6 +54,7 @@
     deps = [
         ":arena_planner",
         "//tensorflow/contrib/lite/testing:util",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "@com_google_googletest//:gtest",
     ],
@@ -63,27 +64,27 @@
 # TODO(aselle): Resolve problems preventing C99 usage.
 cc_library(
     name = "context",
-    srcs = ["context.c"],
     hdrs = ["context.h"],
+    deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
 )
 
 cc_library(
     name = "graph_info",
     hdrs = ["graph_info.h"],
-    deps = [":context"],
+    deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
 )
 
 cc_library(
     name = "memory_planner",
     hdrs = ["memory_planner.h"],
-    deps = [":context"],
+    deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
 )
 
 cc_library(
     name = "simple_memory_arena",
     srcs = ["simple_memory_arena.cc"],
     hdrs = ["simple_memory_arena.h"],
-    deps = [":context"],
+    deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
 )
 
 cc_library(
@@ -91,7 +92,7 @@
     hdrs = [
         "builtin_op_data.h",
     ],
-    deps = [":context"],
+    deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
 )
 
 cc_library(
@@ -121,12 +122,12 @@
     name = "framework",
     srcs = [
         "allocation.cc",
-        "error_reporter.cc",
         "graph_info.cc",
         "interpreter.cc",
         "model.cc",
-        "op_resolver.cc",
+        "mutable_op_resolver.cc",
         "optional_debug_tools.cc",
+        "stderr_reporter.cc",
     ] + select({
         "//tensorflow:android": [
             "nnapi_delegate.cc",
@@ -149,9 +150,11 @@
         "graph_info.h",
         "interpreter.h",
         "model.h",
+        "mutable_op_resolver.h",
         "nnapi_delegate.h",
         "op_resolver.h",
         "optional_debug_tools.h",
+        "stderr_reporter.h",
     ],
     copts = tflite_copts(),
     linkopts = [
@@ -164,14 +167,14 @@
     }),
     deps = [
         ":arena_planner",
-        ":builtin_op_data",
-        ":context",
         ":graph_info",
         ":memory_planner",
         ":schema_fbs_version",
         ":simple_memory_arena",
         ":string",
         ":util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+        "//tensorflow/contrib/lite/core/api",
         "//tensorflow/contrib/lite/kernels:eigen_support",
         "//tensorflow/contrib/lite/kernels:gemm_support",
         "//tensorflow/contrib/lite/nnapi:nnapi_lib",
@@ -210,6 +213,8 @@
     deps = [
         ":framework",
         ":string_util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+        "//tensorflow/contrib/lite/core/api",
         "//tensorflow/contrib/lite/kernels:builtin_ops",
         "//tensorflow/contrib/lite/kernels:kernel_util",
         "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
@@ -259,6 +264,8 @@
     ],
     deps = [
         ":framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+        "//tensorflow/contrib/lite/core/api",
         "//tensorflow/contrib/lite/testing:util",
         "@com_google_googletest//:gtest",
     ],
@@ -266,9 +273,9 @@
 
 # Test OpResolver.
 cc_test(
-    name = "op_resolver_test",
+    name = "mutable_op_resolver_test",
     size = "small",
-    srcs = ["op_resolver_test.cc"],
+    srcs = ["mutable_op_resolver_test.cc"],
     tags = ["no_oss"],
     deps = [
         ":framework",
@@ -277,24 +284,12 @@
     ],
 )
 
-# Test the C extension API code.
-cc_test(
-    name = "context_test",
-    size = "small",
-    srcs = ["context_test.cc"],
-    deps = [
-        ":framework",
-        "//tensorflow/contrib/lite/testing:util",
-        "@com_google_googletest//:gtest",
-    ],
-)
-
 cc_library(
     name = "util",
     srcs = ["util.cc"],
     hdrs = ["util.h"],
     deps = [
-        ":context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ],
 )
 
@@ -304,7 +299,6 @@
     srcs = ["util_test.cc"],
     tags = ["no_oss"],
     deps = [
-        ":context",
         ":util",
         "//tensorflow/contrib/lite/testing:util",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md
deleted file mode 100644
index 8fd63d5..0000000
--- a/tensorflow/contrib/lite/RELEASE.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Release 0.1.7
-
-* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit
-  fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0).
-* To reproduce the iOS library, it's required to cherry pick git commit
-  f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue.
-* The code is based on TensorFlow 1.8.0 release candidate and it's very close
-  to TensorFlow 1.8.0 release.
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index 8946261..21cb183 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -23,8 +23,8 @@
 #include <cstring>
 #include <utility>
 
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 6ffdb41..f0dcffe 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -20,8 +20,8 @@
 #include <cstdio>
 #include <cstdlib>
 #include <vector>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/simple_memory_arena.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 55003cf..3825770 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -18,7 +18,7 @@
 #include <memory>
 #include <vector>
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/graph_info.h"
 #include "tensorflow/contrib/lite/memory_planner.h"
 #include "tensorflow/contrib/lite/simple_memory_arena.h"
@@ -37,8 +37,8 @@
 // each tensor needs to be allocated and deallocated, and preallocates all the
 // necessary memory (the PlanAllocations phase). It then assigns portions of
 // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the buffer if a tensor B is to be allocated after another tensor
-// A has been deallocated.
+// share some of the buffer if a tensor B is to be allocated after another
+// tensor A has been deallocated.
 //
 // If dynamic tensors are used the planning steps can be repeated during model
 // execution. Since dynamic tensors don't have sizes until after the
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index fc199f0..5c705ea 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -49,6 +49,9 @@
     Returns:
        a select object with proper linkopts
     """
+
+    # In case you wonder why there's no --icf: the gains were
+    # negligible, and it created potential compatibility problems.
     return select({
         "//tensorflow:android": [
             "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
@@ -56,12 +59,7 @@
             "-Wl,--gc-sections",  # Eliminate unused code and data.
             "-Wl,--as-needed",  # Don't link unused libs.
         ],
-        "//tensorflow:darwin": [],
-        "//tensorflow/contrib/lite:mips": [],
-        "//tensorflow/contrib/lite:mips64": [],
-        "//conditions:default": [
-            "-Wl,--icf=all",  # Identical code folding.
-        ],
+        "//conditions:default": [],
     })
 
 def tflite_jni_linkopts_unstripped():
@@ -73,17 +71,15 @@
     Returns:
        a select object with proper linkopts
     """
+
+    # In case you wonder why there's no --icf: the gains were
+    # negligible, and it created potential compatibility problems.
     return select({
         "//tensorflow:android": [
             "-Wl,--gc-sections",  # Eliminate unused code and data.
             "-Wl,--as-needed",  # Don't link unused libs.
         ],
-        "//tensorflow:darwin": [],
-        "//tensorflow/contrib/lite:mips": [],
-        "//tensorflow/contrib/lite:mips64": [],
-        "//conditions:default": [
-            "-Wl,--icf=all",  # Identical code folding.
-        ],
+        "//conditions:default": [],
     })
 
 def tflite_linkopts():
@@ -287,6 +283,7 @@
         "sparse_to_dense",
         "split",
         "sqrt",
+        "square",
         "squeeze",
         "strided_slice",
         "strided_slice_1d_exhaustive",
@@ -299,32 +296,74 @@
         "where",
     ]
 
-def gen_zip_test(name, test_name, **kwargs):
+def generated_test_conversion_modes():
+    """Returns a list of conversion modes."""
+
+    # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
+    return ["toco-extended", ""]
+
+def generated_test_models_all():
+    """Generates a list of all tests with the different converters.
+
+    Returns:
+      List of tuples representing (conversion mode, name of test).
+    """
+    conversion_modes = generated_test_conversion_modes()
+    tests = generated_test_models()
+    options = []
+    for conversion_mode in conversion_modes:
+        for test in tests:
+            if conversion_mode:
+                test += "_%s" % conversion_mode
+            options.append((conversion_mode, test))
+    return options
+
+def gen_zip_test(name, test_name, conversion_mode, **kwargs):
     """Generate a zipped-example test and its dependent zip files.
 
     Args:
-      name: Resulting cc_test target name
-      test_name: Test targets this model. Comes from the list above.
-      **kwargs: tf_cc_test kwargs.
+      name: str. Resulting cc_test target name
+      test_name: str. Test targets this model. Comes from the list above.
+      conversion_mode: str. Which conversion mode to run with. Comes from the
+        list above.
+      **kwargs: tf_cc_test kwargs
     """
+    toco = "//tensorflow/contrib/lite/toco:toco"
+    flags = ""
+    if conversion_mode:
+        # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
+        # if conversion_mode == "pb2lite":
+        #     toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite"
+        flags = "--ignore_toco_errors --run_with_extended"
+        kwargs["tags"].append("skip_already_failing")
+        kwargs["tags"].append("no_oss")
+
+        # TODO(b/115504899): Re-enable asan, msan and tsan tests.
+        kwargs["tags"].append("noasan")
+        kwargs["tags"].append("nomsan")
+        kwargs["tags"].append("notsan")
+
     gen_zipped_test_file(
         name = "zip_%s" % test_name,
         file = "%s.zip" % test_name,
+        toco = toco,
+        flags = flags,
     )
     tf_cc_test(name, **kwargs)
 
-def gen_zipped_test_file(name, file):
+def gen_zipped_test_file(name, file, toco, flags):
     """Generate a zip file of tests by using :generate_examples.
 
     Args:
-      name: Name of output. We will produce "`file`.files" as a target.
-      file: The name of one of the generated_examples targets, e.g. "transpose"
+      name: str. Name of output. We will produce "`file`.files" as a target.
+      file: str. The name of one of the generated_examples targets, e.g. "transpose"
+      toco: str. Pathname of toco binary to run
+      flags: str. Any additional flags to include
     """
-    toco = "//tensorflow/contrib/lite/toco:toco"
     native.genrule(
         name = file + ".files",
-        cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco +
-               " --zip_to_output " + file + " $(@D)"),
+        cmd = (("$(locations :generate_examples) --toco $(locations {0}) " +
+                " --zip_to_output {1} {2} $(@D)").format(toco, file, flags)),
         outs = [file],
         tools = [
             ":generate_examples",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index e81f9e4..30901bd 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,287 +12,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
 #ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
 #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
 
-#include <stdint.h>
-
-#include "tensorflow/contrib/lite/context.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Possible padding types (for convolutions)
-typedef enum {
-  kTfLitePaddingUnknown = 0,
-  kTfLitePaddingSame,
-  kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef struct {
-  int width;
-  int height;
-} TfLitePaddingValues;
-
-// Possible fused activation functions.
-// TODO(aselle): rename to TfLiteActivation
-typedef enum {
-  kTfLiteActNone = 0,
-  kTfLiteActRelu,
-  kTfLiteActRelu1,
-  kTfLiteActRelu6,
-  kTfLiteActTanh,
-  kTfLiteActSignBit,
-  kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  int dilation_width_factor;
-  int dilation_height_factor;
-  TfLiteFusedActivation activation;
-} TfLiteConvParams;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  int filter_width;
-  int filter_height;
-  TfLiteFusedActivation activation;
-  struct {
-    TfLitePaddingValues padding;
-  } computed;
-} TfLitePoolParams;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-  int depth_multiplier;
-  TfLiteFusedActivation activation;
-} TfLiteDepthwiseConvParams;
-
-typedef struct {
-  int rank;
-  TfLiteFusedActivation activation;
-} TfLiteSVDFParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteRNNParams;
-
-typedef struct {
-  bool time_major;
-  TfLiteFusedActivation activation;
-} TfLiteSequenceRNNParams;
-
-typedef enum {
-  kTfLiteFullyConnectedWeightsFormatDefault = 0,
-  kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
-} TfLiteFullyConnectedWeightsFormat;
-
-typedef struct {
-  // Parameters for FullyConnected version 1 or above.
-  TfLiteFusedActivation activation;
-
-  // Parameters for FullyConnected version 2 or above.
-  TfLiteFullyConnectedWeightsFormat weights_format;
-} TfLiteFullyConnectedParams;
-
-typedef enum {
-  kTfLiteLshProjectionUnknown = 0,
-  kTfLiteLshProjectionSparse = 1,
-  kTfLiteLshProjectionDense = 2,
-} TfLiteLSHProjectionType;
-
-typedef struct {
-  TfLiteLSHProjectionType type;
-} TfLiteLSHProjectionParams;
-
-typedef struct {
-  float beta;
-} TfLiteSoftmaxParams;
-
-typedef struct {
-  int axis;
-  TfLiteFusedActivation activation;
-} TfLiteConcatenationParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteAddParams;
-
-typedef struct {
-} TfLiteSpaceToBatchNDParams;
-
-typedef struct {
-} TfLiteBatchToSpaceNDParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteMulParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteSubParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteDivParams;
-
-typedef struct {
-  TfLiteFusedActivation activation;
-} TfLiteL2NormParams;
-
-typedef struct {
-  int radius;
-  float bias;
-  float alpha;
-  float beta;
-} TfLiteLocalResponseNormParams;
-
-typedef enum {
-  kTfLiteLSTMFullKernel = 0,
-  kTfLiteLSTMBasicKernel
-} TfLiteLSTMKernelType;
-
-typedef struct {
-  // Parameters for LSTM version 1.
-  TfLiteFusedActivation activation;
-  float cell_clip;
-  float proj_clip;
-
-  // Parameters for LSTM version 2.
-  // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
-  TfLiteLSTMKernelType kernel_type;
-} TfLiteLSTMParams;
-
-typedef struct {
-  bool align_corners;
-} TfLiteResizeBilinearParams;
-
-typedef struct {
-} TfLitePadParams;
-
-typedef struct {
-} TfLitePadV2Params;
-
-typedef struct {
-  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
-  // For now we will fix the maximum possible number of dimensions.
-  int shape[8];
-  int num_dimensions;
-} TfLiteReshapeParams;
-
-typedef struct {
-  int ngram_size;
-  int max_skip_size;
-  bool include_all_ngrams;
-} TfLiteSkipGramParams;
-
-typedef struct {
-  int block_size;
-} TfLiteSpaceToDepthParams;
-
-typedef struct {
-  TfLiteType in_data_type;
-  TfLiteType out_data_type;
-} TfLiteCastParams;
-
-typedef enum {
-  kTfLiteCombinerTypeSum = 0,
-  kTfLiteCombinerTypeMean = 1,
-  kTfLiteCombinerTypeSqrtn = 2,
-} TfLiteCombinerType;
-
-typedef struct {
-  TfLiteCombinerType combiner;
-} TfLiteEmbeddingLookupSparseParams;
-
-typedef struct {
-  int axis;
-} TfLiteGatherParams;
-
-typedef struct {
-} TfLiteTransposeParams;
-
-typedef struct {
-  bool keep_dims;
-} TfLiteReducerParams;
-
-typedef struct {
-  int num_splits;
-} TfLiteSplitParams;
-
-typedef struct {
-  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
-  // For now we will fix the maximum possible number of dimensions.
-  int squeeze_dims[8];
-  int num_squeeze_dims;
-} TfLiteSqueezeParams;
-
-typedef struct {
-  int begin_mask;
-  int end_mask;
-  int ellipsis_mask;
-  int new_axis_mask;
-  int shrink_axis_mask;
-} TfLiteStridedSliceParams;
-
-typedef struct {
-  TfLiteType output_type;
-} TfLiteArgMaxParams;
-
-typedef struct {
-  TfLiteType output_type;
-} TfLiteArgMinParams;
-
-typedef struct {
-  TfLitePadding padding;
-  int stride_width;
-  int stride_height;
-} TfLiteTransposeConvParams;
-
-typedef struct {
-  bool validate_indices;
-} TfLiteSparseToDenseParams;
-
-typedef struct {
-  TfLiteType out_type;
-} TfLiteShapeParams;
-
-typedef struct {
-  // Parameters supported by version 1:
-  float min;
-  float max;
-  int num_bits;
-
-  // Parameters supported by version 2:
-  bool narrow_range;
-} TfLiteFakeQuantParams;
-
-typedef struct {
-  int values_count;
-  int axis;
-} TfLitePackParams;
-
-typedef struct {
-  int axis;
-} TfLiteOneHotParams;
-
-typedef struct {
-  int num;
-  int axis;
-} TfLiteUnpackParams;
-
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 #endif  // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 9cf4bea..5e97b77 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -117,6 +117,7 @@
   kTfLiteBuiltinReduceMin = 89,
   kTfLiteBuiltinFloorDiv = 90,
   kTfLiteBuiltinReduceAny = 91,
+  kTfLiteBuiltinSquare = 92,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000..663eb63
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+cc_library(
+    name = "c_api_internal",
+    srcs = ["c_api_internal.c"],
+    hdrs = [
+        "builtin_op_data.h",
+        "c_api_internal.h",
+    ],
+    visibility = [
+        "//tensorflow/contrib/lite:__subpackages__",
+    ],
+)
+
+# Test the C extension API code.
+cc_test(
+    name = "c_api_internal_test",
+    size = "small",
+    srcs = ["c_api_internal_test.cc"],
+    deps = [
+        ":c_api_internal",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_test(
+    name = "builtin_op_data_test",
+    size = "small",
+    srcs = ["builtin_op_data_test.cc"],
+    copts = ["-Wno-unused-variable"],
+    deps = [
+        ":c_api_internal",
+        "@com_google_googletest//:gtest",
+    ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000..be9d551
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,305 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// IMPORTANT: All new members of structs must be added at the end to ensure
+// backwards compatibility.
+
+// Possible padding types (for convolutions)
+typedef enum {
+  kTfLitePaddingUnknown = 0,
+  kTfLitePaddingSame,
+  kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+  int width;
+  int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+  kTfLiteActNone = 0,
+  kTfLiteActRelu,
+  kTfLiteActRelu1,
+  kTfLiteActRelu6,
+  kTfLiteActTanh,
+  kTfLiteActSignBit,
+  kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+  TfLitePadding padding;
+  int stride_width;
+  int stride_height;
+  int dilation_width_factor;
+  int dilation_height_factor;
+  TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+  TfLitePadding padding;
+  int stride_width;
+  int stride_height;
+  int filter_width;
+  int filter_height;
+  TfLiteFusedActivation activation;
+  struct {
+    TfLitePaddingValues padding;
+  } computed;
+} TfLitePoolParams;
+
+typedef struct {
+  // Parameters for DepthwiseConv version 1 or above.
+  TfLitePadding padding;
+  int stride_width;
+  int stride_height;
+  int depth_multiplier;
+  TfLiteFusedActivation activation;
+  // Parameters for DepthwiseConv version 2 or above.
+  int dilation_width_factor;
+  int dilation_height_factor;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+  int rank;
+  TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+  bool time_major;
+  TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+  kTfLiteFullyConnectedWeightsFormatDefault = 0,
+  kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+  // Parameters for FullyConnected version 1 or above.
+  TfLiteFusedActivation activation;
+
+  // Parameters for FullyConnected version 2 or above.
+  TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+  kTfLiteLshProjectionUnknown = 0,
+  kTfLiteLshProjectionSparse = 1,
+  kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+  TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+  float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+  int axis;
+  TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+  TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+  int radius;
+  float bias;
+  float alpha;
+  float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+  kTfLiteLSTMFullKernel = 0,
+  kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+  // Parameters for LSTM version 1.
+  TfLiteFusedActivation activation;
+  float cell_clip;
+  float proj_clip;
+
+  // Parameters for LSTM version 2.
+  // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+  TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+  bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+  // For now we will fix the maximum possible number of dimensions.
+  int shape[8];
+  int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+  int ngram_size;
+  int max_skip_size;
+  bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+  int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+  TfLiteType in_data_type;
+  TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+  kTfLiteCombinerTypeSum = 0,
+  kTfLiteCombinerTypeMean = 1,
+  kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+  TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+  int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+  bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+  int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+  // For now we will fix the maximum possible number of dimensions.
+  int squeeze_dims[8];
+  int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+  int begin_mask;
+  int end_mask;
+  int ellipsis_mask;
+  int new_axis_mask;
+  int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+  TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+  TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+  TfLitePadding padding;
+  int stride_width;
+  int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+  bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+  TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+  // Parameters supported by version 1:
+  float min;
+  float max;
+  int num_bits;
+
+  // Parameters supported by version 2:
+  bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+  int values_count;
+  int axis;
+} TfLitePackParams;
+
+typedef struct {
+  int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+  int num;
+  int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000..4d0ba75
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+  TfLitePadding padding = kTfLitePaddingSame;
+  TfLitePaddingValues padding_values;
+  TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+  TfLiteConvParams conv_params;
+  TfLitePoolParams pool_params;
+  TfLiteDepthwiseConvParams depthwise_conv_params;
+  TfLiteSVDFParams svdf_params;
+  TfLiteRNNParams rnn_params;
+  TfLiteSequenceRNNParams sequence_rnn_params;
+  TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+      kTfLiteFullyConnectedWeightsFormatDefault;
+  TfLiteFullyConnectedParams fully_connected_params;
+  TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+  TfLiteLSHProjectionParams projection_params;
+  TfLiteSoftmaxParams softmax_params;
+  TfLiteConcatenationParams concatenation_params;
+  TfLiteAddParams add_params;
+  TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+  TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+  TfLiteMulParams mul_params;
+  TfLiteSubParams sub_params;
+  TfLiteDivParams div_params;
+  TfLiteL2NormParams l2_norm_params;
+  TfLiteLocalResponseNormParams local_response_norm_params;
+  TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+  TfLiteLSTMParams lstm_params;
+  TfLiteResizeBilinearParams resize_bilinear_params;
+  TfLitePadParams pad_params;
+  TfLitePadV2Params pad_v2_params;
+  TfLiteReshapeParams reshape_params;
+  TfLiteSkipGramParams skip_gram_params;
+  TfLiteSpaceToDepthParams space_to_depth_params;
+  TfLiteCastParams cast_params;
+  TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+  TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+  TfLiteGatherParams gather_params;
+  TfLiteTransposeParams transpose_params;
+  TfLiteReducerParams reducer_params;
+  TfLiteSplitParams split_params;
+  TfLiteSqueezeParams squeeze_params;
+  TfLiteStridedSliceParams strided_slice_params;
+  TfLiteArgMaxParams arg_max_params;
+  TfLiteArgMinParams arg_min_params;
+  TfLiteTransposeConvParams transpose_conv_params;
+  TfLiteSparseToDenseParams sparse_to_dense_params;
+  TfLiteShapeParams shape_params;
+  TfLiteFakeQuantParams fake_quant_params;
+  TfLitePackParams pack_params;
+  TfLiteOneHotParams one_hot_params;
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c
similarity index 95%
rename from tensorflow/contrib/lite/context.c
rename to tensorflow/contrib/lite/c/c_api_internal.c
index 7f2aa31..1846bad 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -13,8 +13,9 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>
 
 int TfLiteIntArrayGetSizeInBytes(int size) {
@@ -76,7 +77,8 @@
 void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
                        TfLiteQuantizationParams quantization, char* buffer,
                        size_t size, TfLiteAllocationType allocation_type,
-                       const void* allocation, bool is_variable, TfLiteTensor* tensor) {
+                       const void* allocation, bool is_variable,
+                       TfLiteTensor* tensor) {
   TfLiteTensorFree(tensor);
   tensor->type = type;
   tensor->name = name;
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000..ee3dff6
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,496 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using c++ but the interface between
+// the interpreter and the operations are C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controlled by one of the
+// corresponding support files.
+typedef enum {
+  kTfLiteEigenContext = 0,     // include eigen_support.h to use.
+  kTfLiteGemmLowpContext = 1,  // include gemm_support.h to use.
+  kTfLiteEdgeTpuContext = 2,   // Placeholder for Edge TPU support.
+  kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+  TfLiteExternalContextType type;
+  TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+  int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+    __GNUC_MINOR__ >= 1
+  int data[0];
+#else
+  int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create an array of a given `size` (uninitialized entries).
+// This returns a pointer, that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two int arrays are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free memory with TfLiteIntArrayFree
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg)            \
+  do {                                                     \
+    if (!(value)) {                                        \
+      (context)->ReportError((context), __FILE__ " " msg); \
+      return kTfLiteError;                                 \
+    }                                                      \
+  } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a)                                          \
+  do {                                                                      \
+    if (!(a)) {                                                             \
+      (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+                             __LINE__, #a);                                 \
+      return kTfLiteError;                                                  \
+    }                                                                       \
+  } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+  do {                           \
+    if ((a) != kTfLiteOk) {      \
+      return kTfLiteError;       \
+    }                            \
+  } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b)                                       \
+  do {                                                                         \
+    if ((a) != (b)) {                                                          \
+      (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+                             __LINE__, #a, #b, (a), (b));                      \
+      return kTfLiteError;                                                     \
+    }                                                                          \
+  } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+  do {                                     \
+    if ((status) != kTfLiteOk) {           \
+      return kTfLiteError;                 \
+    }                                      \
+  } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+  float re, im;  // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+  kTfLiteNoType = 0,
+  kTfLiteFloat32 = 1,
+  kTfLiteInt32 = 2,
+  kTfLiteUInt8 = 3,
+  kTfLiteInt64 = 4,
+  kTfLiteString = 5,
+  kTfLiteBool = 6,
+  kTfLiteInt16 = 7,
+  kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+//    real_value = scale * (quantized_value - zero_point);
+typedef struct {
+  float scale;
+  int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+  int* i32;
+  int64_t* i64;
+  float* f;
+  char* raw;
+  const char* raw_const;
+  uint8_t* uint8;
+  bool* b;
+  int16_t* i16;
+  TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+  kTfLiteMemNone = 0,
+  kTfLiteMmapRo,
+  kTfLiteArenaRw,
+  kTfLiteArenaRwPersistent,
+  kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved for unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// A tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+  // The data type specification for data stored in `data`. This affects
+  // what member of `data` union should be used.
+  TfLiteType type;
+  // A union of data pointers. The appropriate type should be used for a typed
+  // tensor based on `type`.
+  TfLitePtrUnion data;
+  // A pointer to a structure representing the dimensionality interpretation
+  // that the buffer should have. NOTE: the product of elements of `dims`
+  // and the element datatype size should be equal to `bytes` below.
+  TfLiteIntArray* dims;
+  // Quantization information.
+  TfLiteQuantizationParams params;
+  // How memory is mapped
+  //  kTfLiteMmapRo: Memory mapped read only.
+  //  i.e. weights
+  //  kTfLiteArenaRw: Arena allocated read write memory
+  //  (i.e. temporaries, outputs).
+  TfLiteAllocationType allocation_type;
+  // The number of bytes required to store the data of this Tensor. I.e.
+  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
+  // type is kTfLiteFloat32 and dims = {3, 2} then
+  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+  size_t bytes;
+
+  // An opaque pointer to a tflite::MMapAllocation
+  const void* allocation;
+
+  // Null-terminated name of this tensor.
+  const char* name;
+
+  // The delegate which knows how to handle `buffer_handle`.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteDelegate* delegate;
+
+  // An integer buffer handle that can be handled by `delegate`.
+  // The value is valid only when delegate is not null.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteBufferHandle buffer_handle;
+
+  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+  // responsible to set data_is_stale to true.
+  // `delegate->CopyFromBufferHandle` can be called to copy the data from
+  // delegate buffer.
+  // WARNING: This is an experimental interface that is subject to change.
+  bool data_is_stale;
+
+  // True if the tensor is a variable.
+  bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+                       TfLiteQuantizationParams quantization, char* buffer,
+                       size_t size, TfLiteAllocationType allocation_type,
+                       const void* allocation, bool is_variable,
+                       TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+  // Inputs to this node expressed as indices into the simulator's tensors.
+  TfLiteIntArray* inputs;
+
+  // Outputs to this node expressed as indices into the simulator's tensors.
+  TfLiteIntArray* outputs;
+
+  // Temporary tensors used during the computations. This usually contains no
+  // tensors, but ops are allowed to change that if they need scratch space of
+  // any sort.
+  TfLiteIntArray* temporaries;
+
+  // Opaque data provided by the node implementer through `Registration.init`.
+  void* user_data;
+
+  // Opaque data provided to the node if the node is a builtin. This is usually
+  // a structure defined in builtin_op_data.h
+  void* builtin_data;
+
+  // Custom initial data. This is the opaque data provided in the flatbuffer.
+  // WARNING: This is an experimental interface that is subject to change.
+  const void* custom_initial_data;
+  int custom_initial_data_size;
+
+  // The pointer to the delegate. This is non-null only when the node is
+  // created by calling `interpreter.ModifyGraphWithDelegate`.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+  // Number of tensors in the context.
+  size_t tensors_size;
+
+  // The execution plan contains a list of the node indices in execution
+  // order. execution_plan->size is the current number of nodes. And,
+  // execution_plan->data[0] is the first node that needs to be run.
+  // TfLiteDelegates can traverse the current execution plan by iterating
+  // through each member of this array and using GetNodeAndRegistration() to
+  // access details about a node. i.e.
+  // TfLiteIntArray* execution_plan;
+  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+  // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+  //    int node_index = execution_plan->data[exec_index];
+  //    TfLiteNode* node;
+  //    TfLiteRegistration* reg;
+  //    context->GetNodeAndRegistration(context, node_index, &node, &reg);
+  // }
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+                                   TfLiteIntArray** execution_plan);
+
+  // An array of tensors in the interpreter context (of length `tensors_size`)
+  TfLiteTensor* tensors;
+
+  // opaque full context ptr (an opaque c++ data structure)
+  void* impl_;
+
+  // Request memory pointer be resized. Updates dimensions on the tensor.
+  // NOTE: ResizeTensor takes ownership of `new_size`.
+  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+                               TfLiteIntArray* new_size);
+  // Request that an error be reported with format string msg.
+  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
+  // non-null, the value pointed to by `first_new_tensor_index` will be set to
+  // the index of the first new tensor.
+  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+                             int* first_new_tensor_index);
+
+  // Get a node and its registration by node_index.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+                                         TfLiteNode** node,
+                                         TfLiteRegistration** registration);
+
+  // Replace ops with one or more stub delegate operations. This function
+  // does not take ownership of `nodes_to_replace`.
+  TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+      struct TfLiteContext*, TfLiteRegistration registration,
+      const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+  // Number of threads that are recommended to subsystems like gemmlowp and
+  // eigen.
+  int recommended_num_threads;
+
+  // Access external contexts by type.
+  // WARNING: This is an experimental interface that is subject to change.
+  TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+                                               TfLiteExternalContextType);
+  // Set the value of an external context. Does not take ownership of the
+  // pointer.
+  // WARNING: This is an experimental interface that is subject to change.
+  void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+                             TfLiteExternalContext*);
+
+  // Flag for allowing float16 precision for FP32 calculation.
+  // default: false.
+  // WARNING: This is an experimental API and subject to change.
+  bool allow_fp32_relax_to_fp16;
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+  // Initializes the op from serialized data.
+  // If a built-in op:
+  //   `buffer` is the op's params data (TfLiteLSTMParams*).
+  //   `length` is zero.
+  // If custom op:
+  //   `buffer` is the op's `custom_options`.
+  //   `length` is the size of the buffer.
+  //
+  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+  // or an instance of a struct).
+  //
+  // The returned pointer will be stored with the node in the `user_data` field,
+  // accessible within prepare and invoke functions below.
+  // NOTE: if the data is already in the desired format, simply implement this
+  // function to return `nullptr` and implement the free function to be a no-op.
+  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+  // The pointer `buffer` is the data previously returned by an init invocation.
+  void (*free)(TfLiteContext* context, void* buffer);
+
+  // prepare is called when the inputs this node depends on have been resized.
+  // context->ResizeTensor() can be called to request output tensors to be
+  // resized.
+  //
+  // Returns kTfLiteOk on success.
+  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+  // Execute the node (should read node->inputs and output to node->outputs).
+  // Returns kTfLiteOk on success.
+  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+  // profiling_string is called during summarization of profiling information
+  // in order to group executions together. Providing a value here will cause a
+  // given op to appear multiple times in the profiling report. This is
+  // particularly useful for custom ops that can perform significantly
+  // different calculations depending on their `user-data`.
+  const char* (*profiling_string)(const TfLiteContext* context,
+                                  const TfLiteNode* node);
+
+  // Builtin codes. If this kernel refers to a builtin this is the code
+  // of the builtin. This is so we can do marshaling to other frameworks like
+  // NN API.
+  // Note: It is the responsibility of the registration binder to set this
+  // properly.
+  int32_t builtin_code;
+
+  // Custom op name. If the op is a builtin, this will be null.
+  // Note: It is the responsibility of the registration binder to set this
+  // properly.
+  // WARNING: This is an experimental interface that is subject to change.
+  const char* custom_name;
+
+  // The version of the op.
+  // Note: It is the responsibility of the registration binder to set this
+  // properly.
+  int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+  // Data that delegate needs to identify itself. This data is owned by the
+  // delegate. The delegate is owned in the user code, so the delegate is
+  // responsible for doing this when it is destroyed.
+  void* data_;
+
+  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+  // delegate a view of the current graph through TfLiteContext*. It typically
+  // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+  // to ask the TensorFlow lite runtime to create macro-nodes to represent
+  // delegated subgraphs of the original graph.
+  TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+  // Copy the data from delegate buffer handle to raw memory.
+  // This can be null if the delegate doesn't use its own buffer.
+  TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+                                       TfLiteDelegate* delegate,
+                                       TfLiteBufferHandle buffer_handle,
+                                       void* data, size_t size);
+
+  // Copy the data from raw memory to delegate buffer handle.
+  // This can be null if the delegate doesn't use its own buffer.
+  TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+                                     TfLiteDelegate* delegate,
+                                     TfLiteBufferHandle buffer_handle,
+                                     void* data, size_t size);
+
+  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
+  // this doesn't release the underlying resource (e.g. textures). The
+  // resources are either owned by application layer or the delegate.
+  // This can be null if the delegate doesn't use its own buffer.
+  void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+                           TfLiteBufferHandle* handle);
+} TfLiteDelegate;
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+// trivially destructible. It will be stored as `builtin_data` field in
+// `TfLiteNode` of the delegate node.
+//
+// See also the `CreateDelegateParams` function in `interpreter.cc` for details.
+typedef struct {
+  TfLiteDelegate* delegate;
+  TfLiteIntArray* nodes_to_replace;
+  TfLiteIntArray* input_tensors;
+  TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+#endif  // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
similarity index 87%
rename from tensorflow/contrib/lite/context_test.cc
rename to tensorflow/contrib/lite/c/c_api_internal_test.cc
index 20d6f69..af398f3 100644
--- a/tensorflow/contrib/lite/context_test.cc
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -13,16 +13,15 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/testing/util.h"
 
 namespace tflite {
 
 // NOTE: this tests only the TfLiteIntArray part of context.
-// most of context.h is provided in the context of using it with interpreter.h
-// and interpreter.cc, so interpreter_test.cc tests context structures more
-// thoroughly.
+// most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
 
 TEST(IntArray, TestIntArrayCreate) {
   TfLiteIntArray* a = TfLiteIntArrayCreate(0);
@@ -69,7 +68,6 @@
 }  // namespace tflite
 
 int main(int argc, char** argv) {
-  ::tflite::LogToStderr();
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
 }
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index c7f4df3..b86c281 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -12,480 +12,10 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// This file defines a C API for implementing operations in tflite.
-// These operations can be defined using c++ but the interface between
-// the interpreter and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-//
-// Some abstractions in this file are created and managed by Interpreter.
+// Compatibility shim for moved header location.
 #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
 #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
 
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdlib.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
-  kTfLiteEigenContext = 0,     // include eigen_support.h to use.
-  kTfLiteGemmLowpContext = 1,  // include gemm_support.h to use.
-  kTfLiteEdgeTpuContext = 2,   // Placeholder for Edge TPU support.
-  kTfLiteMaxExternalContexts = 3
-} TfLiteExternalContextType;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
-  TfLiteExternalContextType type;
-  TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-// Forward declare so GetNode can use this is in Context.
-typedef struct _TfLiteRegistration TfLiteRegistration;
-typedef struct _TfLiteDelegate TfLiteDelegate;
-
-#define kOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
-  int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
-    __GNUC_MINOR__ >= 1
-  int data[0];
-#else
-  int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
-
-// Free memory of array `v`.
-void TfLiteIntArrayFree(TfLiteIntArray* v);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg)            \
-  do {                                                     \
-    if (!(value)) {                                        \
-      (context)->ReportError((context), __FILE__ " " msg); \
-      return kTfLiteError;                                 \
-    }                                                      \
-  } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a)                                          \
-  do {                                                                      \
-    if (!(a)) {                                                             \
-      (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
-                             __LINE__, #a);                                 \
-      return kTfLiteError;                                                  \
-    }                                                                       \
-  } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
-  do {                           \
-    if ((a) != kTfLiteOk) {      \
-      return kTfLiteError;       \
-    }                            \
-  } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b)                                       \
-  do {                                                                         \
-    if ((a) != (b)) {                                                          \
-      (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
-                             __LINE__, #a, #b, (a), (b));                      \
-      return kTfLiteError;                                                     \
-    }                                                                          \
-  } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
-  do {                                     \
-    if ((status) != kTfLiteOk) {           \
-      return status;                       \
-    }                                      \
-  } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
-  float re, im;  // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Types supported by tensor
-typedef enum {
-  kTfLiteNoType = 0,
-  kTfLiteFloat32 = 1,
-  kTfLiteInt32 = 2,
-  kTfLiteUInt8 = 3,
-  kTfLiteInt64 = 4,
-  kTfLiteString = 5,
-  kTfLiteBool = 6,
-  kTfLiteInt16 = 7,
-  kTfLiteComplex64 = 8,
-} TfLiteType;
-
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-//    real_value = scale * (quantized_value - zero_point);
-typedef struct {
-  float scale;
-  int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// A union of pointers that points to memory for a given tensor.
-typedef union {
-  int* i32;
-  int64_t* i64;
-  float* f;
-  char* raw;
-  const char* raw_const;
-  uint8_t* uint8;
-  bool* b;
-  int16_t* i16;
-  TfLiteComplex64* c64;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
-  kTfLiteMemNone = 0,
-  kTfLiteMmapRo,
-  kTfLiteArenaRw,
-  kTfLiteArenaRwPersistent,
-  kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
-  // The data type specification for data stored in `data`. This affects
-  // what member of `data` union should be used.
-  TfLiteType type;
-  // A union of data pointers. The appropriate type should be used for a typed
-  // tensor based on `type`.
-  TfLitePtrUnion data;
-  // A pointer to a structure representing the dimensionality interpretation
-  // that the buffer should have. NOTE: the product of elements of `dims`
-  // and the element datatype size should be equal to `bytes` below.
-  TfLiteIntArray* dims;
-  // Quantization information.
-  TfLiteQuantizationParams params;
-  // How memory is mapped
-  //  kTfLiteMmapRo: Memory mapped read only.
-  //  i.e. weights
-  //  kTfLiteArenaRw: Arena allocated read write memory
-  //  (i.e. temporaries, outputs).
-  TfLiteAllocationType allocation_type;
-  // The number of bytes required to store the data of this Tensor. I.e.
-  // (bytes of each element) * dims[0] * ... * dims[n-1].  For example, if
-  // type is kTfLiteFloat32 and dims = {3, 2} then
-  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
-  size_t bytes;
-
-  // An opaque pointer to a tflite::MMapAllocation
-  const void* allocation;
-
-  // Null-terminated name of this tensor.
-  const char* name;
-
-  // The delegate which knows how to handle `buffer_handle`.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
-
-  // An integer buffer handle that can be handled by `delegate`.
-  // The value is valid only when delegate is not null.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteBufferHandle buffer_handle;
-
-  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
-  // responsible to set data_is_stale to true.
-  // `delegate->CopyFromBufferHandle` can be called to copy the data from
-  // delegate buffer.
-  // WARNING: This is an // experimental interface that is subject to change.
-  bool data_is_stale;
-
-  // True if the tensor is a variable.
-  bool is_variable;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`;
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free memory of tensor `t`;
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
-                       TfLiteQuantizationParams quantization, char* buffer,
-                       size_t size, TfLiteAllocationType allocation_type,
-                       const void* allocation, bool is_variable,
-                       TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
-  // Inputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* inputs;
-
-  // Outputs to this node expressed as indices into the simulator's tensors.
-  TfLiteIntArray* outputs;
-
-  // Temporary tensors uses during the computations. This usually contains no
-  // tensors, but ops are allowed to change that if they need scratch space of
-  // any sort.
-  TfLiteIntArray* temporaries;
-
-  // Opaque data provided by the node implementer through `Registration.init`.
-  void* user_data;
-
-  // Opaque data provided to the node if the node is a builtin. This is usually
-  // a structure defined in builtin_op_data.h
-  void* builtin_data;
-
-  // Custom initial data. This is the opaque data provided in the flatbuffer.
-  // WARNING: This is an experimental interface that is subject to change.
-  const void* custom_initial_data;
-  int custom_initial_data_size;
-
-  // The pointer to the delegate. This is non-null only when the node is
-  // created by calling `interpreter.ModifyGraphWithDelegate`.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
-  // Number of tensors in the context.
-  size_t tensors_size;
-
-  // The execution plan contains a list of the node indices in execution
-  // order. execution_plan->size is the current number of nodes. And,
-  // execution_plan->data[0] is the first node that needs to be run.
-  // TfLiteDelegates can traverse the current execution plan by iterating
-  // through each member of this array and using GetNodeAndRegistration() to
-  // access details about a node. i.e.
-  // TfLiteIntArray* execution_plan;
-  // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
-  // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
-  //    int node_index = execution_plan->data[exec_index];
-  //    TfLiteNode* node;
-  //    TfLiteRegistration* reg;
-  //    context->GetNodeAndRegistration(context, node_index, &node, &reg);
-  // }
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
-                                   TfLiteIntArray** execution_plan);
-
-  // An array of tensors in the interpreter context (of length `tensors_size`)
-  TfLiteTensor* tensors;
-
-  // opaque full context ptr (an opaque c++ data structure)
-  void* impl_;
-
-  // Request memory pointer be resized. Updates dimensions on the tensor.
-  // NOTE: ResizeTensor takes ownership of newSize.
-  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
-                               TfLiteIntArray* new_size);
-  // Request that a error be reported with format string msg.
-  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
-  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries.  If
-  // non-null, the value pointed to by `first_new_tensor_index` will be set to
-  // the index of the first new tensor.
-  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
-                             int* first_new_tensor_index);
-
-  // Get a Tensor node by node_index.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
-                                         TfLiteNode** node,
-                                         TfLiteRegistration** registration);
-
-  // Replace ops with one or more stub delegate operations. This function
-  // does not take ownership of `nodes_to_replace`.
-  TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
-      struct TfLiteContext*, TfLiteRegistration registration,
-      const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
-
-  // Number of threads that are recommended to subsystems like gemmlowp and
-  // eigen.
-  int recommended_num_threads;
-
-  // Access external contexts by type.
-  // WARNING: This is an experimental interface that is subject to change.
-  TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
-                                               TfLiteExternalContextType);
-  // Set the value of a external context. Does not take ownership of the
-  // pointer.
-  // WARNING: This is an experimental interface that is subject to change.
-  void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
-                             TfLiteExternalContext*);
-} TfLiteContext;
-
-typedef struct _TfLiteRegistration {
-  // Initializes the op from serialized data.
-  // If a built-in op:
-  //   `buffer` is the op's params data (TfLiteLSTMParams*).
-  //   `length` is zero.
-  // If custom op:
-  //   `buffer` is the op's `custom_options`.
-  //   `length` is the size of the buffer.
-  //
-  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
-  // or an instance of a struct).
-  //
-  // The returned pointer will be stored with the node in the `user_data` field,
-  // accessible within prepare and invoke functions below.
-  // NOTE: if the data is already in the desired format, simply implement this
-  // function to return `nullptr` and implement the free function to be a no-op.
-  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
-  // The pointer `buffer` is the data previously returned by an init invocation.
-  void (*free)(TfLiteContext* context, void* buffer);
-
-  // prepare is called when the inputs this node depends on have been resized.
-  // context->ResizeTensor() can be called to request output tensors to be
-  // resized.
-  //
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
-  // Execute the node (should read node->inputs and output to node->outputs).
-  // Returns kTfLiteOk on success.
-  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
-  // profiling_string is called during summarization of profiling information
-  // in order to group executions together. Providing a value here will cause a
-  // given op to appear multiple times is the profiling report. This is
-  // particularly useful for custom ops that can perform significantly
-  // different calculations depending on their `user-data`.
-  const char* (*profiling_string)(const TfLiteContext* context,
-                                  const TfLiteNode* node);
-
-  // Builtin codes. If this kernel refers to a builtin this is the code
-  // of the builtin. This is so we can do marshaling to other frameworks like
-  // NN API.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int32_t builtin_code;
-
-  // Custom op name. If the op is a builtin, this will be null.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  // WARNING: This is an experimental interface that is subject to change.
-  const char* custom_name;
-
-  // The version of the op.
-  // Note: It is the responsibility of the registration binder to set this
-  // properly.
-  int version;
-} TfLiteRegistration;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
-  // Data that delegate needs to identify itself. This data is owned by the
-  // delegate. The delegate is owned in the user code, so the delegate is
-  // responsible for doing this when it is destroyed.
-  void* data_;
-
-  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
-  // delegate a view of the current graph through TfLiteContext*. It typically
-  // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
-  // to ask the TensorFlow lite runtime to create macro-nodes to represent
-  // delegated subgraphs of the original graph.
-  TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
-
-  // Copy the data from delegate buffer handle to raw memory.
-  // This can be null if the delegate doesn't use its own buffer.
-  TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
-                                       TfLiteDelegate* delegate,
-                                       TfLiteBufferHandle buffer_handle,
-                                       void* data, size_t size);
-
-  // Copy the data from raw memory to delegate buffer handle.
-  // This can be null if the delegate doesn't use its own buffer.
-  TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
-                                     TfLiteDelegate* delegate,
-                                     TfLiteBufferHandle buffer_handle,
-                                     void* data, size_t size);
-
-  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
-  // this doesn't release the underlying resource (e.g. textures). The
-  // resources are either owned by application layer or the delegate.
-  // This can be null if the delegate doesn't use its own buffer.
-  void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
-                           TfLiteBufferHandle* handle);
-} TfLiteDelegate;
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
-  TfLiteDelegate* delegate;
-  TfLiteIntArray* nodes_to_replace;
-  TfLiteIntArray* input_tensors;
-  TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
 #endif  // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
index abe802e..ccda4c7 100644
--- a/tensorflow/contrib/lite/context_util.h
+++ b/tensorflow/contrib/lite/context_util.h
@@ -17,7 +17,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
 #define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD
new file mode 100644
index 0000000..e450053
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/BUILD
@@ -0,0 +1,57 @@
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+    name = "api",
+    srcs = [
+        "error_reporter.cc",
+        "flatbuffer_conversions.cc",
+        "op_resolver.cc",
+    ],
+    hdrs = [
+        "error_reporter.h",
+        "flatbuffer_conversions.h",
+        "op_resolver.h",
+    ],
+    copts = tflite_copts(),
+    deps = [
+        "//tensorflow/contrib/lite/c:c_api_internal",
+        "//tensorflow/contrib/lite/schema:schema_fbs",
+    ],
+)
+
+cc_test(
+    name = "error_reporter_test",
+    size = "small",
+    srcs = ["error_reporter_test.cc"],
+    deps = [
+        ":api",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_test(
+    name = "op_resolver_test",
+    size = "small",
+    srcs = ["op_resolver_test.cc"],
+    deps = [
+        ":api",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+cc_test(
+    name = "flatbuffer_conversions_test",
+    size = "small",
+    srcs = ["flatbuffer_conversions_test.cc"],
+    deps = [
+        ":api",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+        "@com_google_googletest//:gtest",
+    ],
+)
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000..423f83b
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  int code = Report(format, args);
+  va_end(args);
+  return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+  va_list args;
+  va_start(args, format);
+  int code = Report(format, args);
+  va_end(args);
+  return code;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h
new file mode 100644
index 0000000..a2f780b
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+// A functor that reports errors to the supporting system. Invoked similarly
+// to printf.
+//
+// Usage:
+//  ErrorReporter foo;
+//  foo.Report("test %d", 5);
+// or
+//  va_list args;
+//  foo.Report("test %d", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+  virtual ~ErrorReporter() {}
+  virtual int Report(const char* format, va_list args) = 0;
+  int Report(const char* format, ...);
+  int ReportError(void*, const char* format, ...);
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000..0463eee
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+  int Report(const char* format, va_list args) override {
+    vsnprintf(buffer_, kBufferSize, format, args);
+    return 0;
+  }
+  char* GetBuffer() { return buffer_; }
+
+ private:
+  static constexpr int kBufferSize = 256;
+  char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+  reporter->Report("Error: %d", 23);
+  EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
new file mode 100644
index 0000000..f4d2839
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -0,0 +1,626 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+
+namespace {
+
+// Copies the contents of the flatbuffer int vector `flat_vector` into the
+// int array `buffer`. Both describe configuration data for the same
+// operation; `buffer` must be at least `max_size_of_buffer` bytes long.
+void FlatBufferIntVectorToArray(int max_size_of_buffer,
+                                const flatbuffers::Vector<int32_t>* flat_vector,
+                                int* buffer, ErrorReporter* error_reporter) {
+  if (!flat_vector) {
+    error_reporter->Report("Input array not provided for operation.\n");
+  } else {
+    int num_dimensions = flat_vector->Length();
+    if (num_dimensions > static_cast<int>(max_size_of_buffer / sizeof(int))) {
+      error_reporter->Report(
+          "Found too many dimensions in the operation's input array.\n");
+    } else {
+      for (int i = 0; i < num_dimensions; ++i) {
+        buffer[i] = flat_vector->Get(i);  // element-by-element copy into the POD array
+      }
+    }
+  }
+}
+
+// Allocate a structure using malloc, but make sure the structure is a POD
+// structure that doesn't require constructors to run. The reason we do this,
+// is that Interpreter's C extension part will take ownership so destructors
+// will not be run during deallocation.
+template <class T>
+T* MallocPOD() {
+  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+  return static_cast<T*>(malloc(sizeof(T)));  // NOTE(review): result is not null-checked; callers assume success
+}
+
+}  // namespace
+
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+                               ErrorReporter* error_reporter) {
+  switch (tensor_type) {  // one-to-one mapping from schema enum to runtime enum
+    case TensorType_FLOAT32:
+      *type = kTfLiteFloat32;
+      break;
+    case TensorType_INT16:
+      *type = kTfLiteInt16;
+      break;
+    case TensorType_INT32:
+      *type = kTfLiteInt32;
+      break;
+    case TensorType_UINT8:
+      *type = kTfLiteUInt8;
+      break;
+    case TensorType_INT64:
+      *type = kTfLiteInt64;
+      break;
+    case TensorType_STRING:
+      *type = kTfLiteString;
+      break;
+    case TensorType_BOOL:
+      *type = kTfLiteBool;
+      break;
+    case TensorType_COMPLEX64:
+      *type = kTfLiteComplex64;
+      break;
+    default:
+      error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+                             EnumNameTensorType(tensor_type), tensor_type);
+      return kTfLiteError;  // *type is left unmodified on failure
+  }
+  return kTfLiteOk;
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
+// needs to be released by calling `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+                         ErrorReporter* error_reporter, void** builtin_data) {
+  auto parse_padding = [](Padding padding) {
+    switch (padding) {
+      case Padding_SAME:
+        return kTfLitePaddingSame;
+      case Padding_VALID:
+        return kTfLitePaddingValid;
+    }
+    return kTfLitePaddingUnknown;
+  };
+  auto parse_activation = [](ActivationFunctionType activation) {
+    switch (activation) {
+      case ActivationFunctionType_NONE:
+        return kTfLiteActNone;
+      case ActivationFunctionType_RELU:
+        return kTfLiteActRelu;
+      case ActivationFunctionType_RELU_N1_TO_1:
+        return kTfLiteActRelu1;
+      case ActivationFunctionType_RELU6:
+        return kTfLiteActRelu6;
+      case ActivationFunctionType_TANH:
+        return kTfLiteActTanh;
+      case ActivationFunctionType_SIGN_BIT:
+        return kTfLiteActSignBit;
+    }
+    return kTfLiteActNone;
+  };
+  auto parse_lsh_projection_type = [](LSHProjectionType type) {  // snake_case for consistency with siblings
+    switch (type) {
+      case LSHProjectionType_SPARSE:
+        return kTfLiteLshProjectionSparse;
+      case LSHProjectionType_DENSE:
+        return kTfLiteLshProjectionDense;
+      default:
+        return kTfLiteLshProjectionUnknown;
+    }
+  };
+  auto parse_combiner_type = [](CombinerType type) {  // snake_case for consistency with siblings
+    switch (type) {
+      case CombinerType_MEAN:
+        return kTfLiteCombinerTypeMean;
+      case CombinerType_SQRTN:
+        return kTfLiteCombinerTypeSqrtn;
+      case CombinerType_SUM:
+      default:
+        return kTfLiteCombinerTypeSum;
+    }
+  };
+
+  *builtin_data = nullptr;  // contract: stays null on every error path
+  switch (op_type) {
+    case BuiltinOperator_CONV_2D: {
+      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+        params->padding = parse_padding(conv_params->padding());
+        params->stride_width = conv_params->stride_w();
+        params->stride_height = conv_params->stride_h();
+        params->activation =
+            parse_activation(conv_params->fused_activation_function());
+
+        params->dilation_width_factor = conv_params->dilation_w_factor();
+        params->dilation_height_factor = conv_params->dilation_h_factor();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_CAST: {
+      TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+      if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+        auto in_status =
+            ConvertTensorType(schema_params->in_data_type(),
+                              &params->in_data_type, error_reporter);
+        auto out_status =
+            ConvertTensorType(schema_params->out_data_type(),
+                              &params->out_data_type, error_reporter);
+        if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+          free(params);  // release before bailing out so params is not leaked
+          return kTfLiteError;
+        }
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_LSH_PROJECTION: {
+      TfLiteLSHProjectionParams* params =
+          MallocPOD<TfLiteLSHProjectionParams>();
+      if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+        params->type = parse_lsh_projection_type(lshParams->type());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_AVERAGE_POOL_2D:
+    case BuiltinOperator_MAX_POOL_2D:
+    case BuiltinOperator_L2_POOL_2D: {
+      TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+      if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+        params->padding = parse_padding(pool_params->padding());
+        params->stride_width = pool_params->stride_w();
+        params->stride_height = pool_params->stride_h();
+        params->filter_width = pool_params->filter_width();
+        params->filter_height = pool_params->filter_height();
+        params->activation =
+            parse_activation(pool_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_DEPTHWISE_CONV_2D: {
+      TfLiteDepthwiseConvParams* params =
+          MallocPOD<TfLiteDepthwiseConvParams>();
+      if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+        params->padding = parse_padding(conv_params->padding());
+        params->stride_width = conv_params->stride_w();
+        params->stride_height = conv_params->stride_h();
+        params->depth_multiplier = conv_params->depth_multiplier();
+        params->activation =
+            parse_activation(conv_params->fused_activation_function());
+
+        params->dilation_width_factor = conv_params->dilation_w_factor();
+        params->dilation_height_factor = conv_params->dilation_h_factor();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SVDF: {
+      TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+      if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+        params->rank = svdf_params->rank();
+        params->activation =
+            parse_activation(svdf_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
+    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+      TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+      if (auto* sequence_rnn_params =
+              op->builtin_options_as_SequenceRNNOptions()) {
+        params->activation =
+            parse_activation(sequence_rnn_params->fused_activation_function());
+        params->time_major = sequence_rnn_params->time_major();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_RNN: {
+      TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+      if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+        params->activation =
+            parse_activation(rnn_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+      TfLiteEmbeddingLookupSparseParams* params =
+          MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+      if (auto* embedding_params =
+              op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+        params->combiner = parse_combiner_type(embedding_params->combiner());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_FULLY_CONNECTED: {
+      TfLiteFullyConnectedParams* params =
+          MallocPOD<TfLiteFullyConnectedParams>();
+      if (auto* fully_connected_params =
+              op->builtin_options_as_FullyConnectedOptions()) {
+        params->activation = parse_activation(
+            fully_connected_params->fused_activation_function());
+        switch (fully_connected_params->weights_format()) {
+          case FullyConnectedOptionsWeightsFormat_DEFAULT:
+            params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+            break;
+          case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+            params->weights_format =
+                kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+            break;
+          default:
+            error_reporter->Report("Unhandled fully-connected weights format.");
+            free(params); return kTfLiteError;  // was leaking params on this path
+        }
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_HASHTABLE_LOOKUP:
+      // no-op.
+      break;
+    case BuiltinOperator_SOFTMAX: {
+      TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+      if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+        params->beta = softmax_params->beta();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_CONCATENATION: {
+      TfLiteConcatenationParams* params =
+          MallocPOD<TfLiteConcatenationParams>();
+      if (auto* concatenation_params =
+              op->builtin_options_as_ConcatenationOptions()) {
+        params->activation =
+            parse_activation(concatenation_params->fused_activation_function());
+        params->axis = concatenation_params->axis();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_MUL: {
+      auto* params = MallocPOD<TfLiteMulParams>();
+      if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+        params->activation =
+            parse_activation(schema_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ADD: {
+      auto* params = MallocPOD<TfLiteAddParams>();
+      if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+        params->activation =
+            parse_activation(schema_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_DIV: {
+      auto* params = MallocPOD<TfLiteDivParams>();
+      if (auto* schema_params = op->builtin_options_as_DivOptions()) {
+        params->activation =
+            parse_activation(schema_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SUB: {
+      auto* params = MallocPOD<TfLiteSubParams>();
+      if (auto* schema_params = op->builtin_options_as_SubOptions()) {
+        params->activation =
+            parse_activation(schema_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_L2_NORMALIZATION: {
+      auto* params = MallocPOD<TfLiteL2NormParams>();
+      if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+        params->activation =
+            parse_activation(schema_params->fused_activation_function());
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+      auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+      if (auto* schema_params =
+              op->builtin_options_as_LocalResponseNormalizationOptions()) {
+        params->radius = schema_params->radius();
+        params->bias = schema_params->bias();
+        params->alpha = schema_params->alpha();
+        params->beta = schema_params->beta();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
+    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+    case BuiltinOperator_LSTM: {
+      TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+      if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+        params->activation =
+            parse_activation(lstm_params->fused_activation_function());
+        params->cell_clip = lstm_params->cell_clip();
+        params->proj_clip = lstm_params->proj_clip();
+        switch (lstm_params->kernel_type()) {
+          case LSTMKernelType_FULL:
+            params->kernel_type = kTfLiteLSTMFullKernel;
+            break;
+          case LSTMKernelType_BASIC:
+            params->kernel_type = kTfLiteLSTMBasicKernel;
+            break;
+        }
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_RESIZE_BILINEAR: {
+      auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+      if (auto* schema_params =
+              op->builtin_options_as_ResizeBilinearOptions()) {
+        params->align_corners = schema_params->align_corners();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_RESHAPE: {
+      auto* params = MallocPOD<TfLiteReshapeParams>();
+      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+        auto* new_shape = schema_params->new_shape();
+        FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+                                   params->shape, error_reporter);
+        params->num_dimensions = new_shape ? new_shape->Length() : 0;  // guard null deref
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SKIP_GRAM: {
+      TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+      if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+        params->ngram_size = skip_gram_params->ngram_size();
+        params->max_skip_size = skip_gram_params->max_skip_size();
+        params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPACE_TO_DEPTH: {
+      auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+      if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+        params->block_size = schema_params->block_size();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_GATHER: {
+      TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+      params->axis = 0;
+      if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+        params->axis = gather_params->axis();
+      }
+
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_MEAN:
+    case BuiltinOperator_REDUCE_MAX:
+    case BuiltinOperator_REDUCE_MIN:
+    case BuiltinOperator_REDUCE_PROD:
+    case BuiltinOperator_REDUCE_ANY:
+    case BuiltinOperator_SUM: {
+      auto* params = MallocPOD<TfLiteReducerParams>();
+      if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+        params->keep_dims = schema_params->keep_dims();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPLIT: {
+      auto* params = MallocPOD<TfLiteSplitParams>();
+      if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+        params->num_splits = schema_params->num_splits();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SQUEEZE: {
+      auto* params = MallocPOD<TfLiteSqueezeParams>();
+      if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+        const auto& squeeze_dims = schema_params->squeeze_dims();
+        FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+                                   params->squeeze_dims, error_reporter);
+        params->num_squeeze_dims = squeeze_dims ? squeeze_dims->Length() : 0;  // guard null deref
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_STRIDED_SLICE: {
+      auto* params = MallocPOD<TfLiteStridedSliceParams>();
+      if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+        params->begin_mask = schema_params->begin_mask();
+        params->end_mask = schema_params->end_mask();
+        params->ellipsis_mask = schema_params->ellipsis_mask();
+        params->new_axis_mask = schema_params->new_axis_mask();
+        params->shrink_axis_mask = schema_params->shrink_axis_mask();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ARG_MAX: {
+      auto* params = MallocPOD<TfLiteArgMaxParams>();
+      if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+        ConvertTensorType(schema_params->output_type(), &params->output_type,
+                          error_reporter);
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ARG_MIN: {
+      auto* params = MallocPOD<TfLiteArgMinParams>();
+      if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+        ConvertTensorType(schema_params->output_type(), &params->output_type,
+                          error_reporter);
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_TRANSPOSE_CONV: {
+      TfLiteTransposeConvParams* params =
+          MallocPOD<TfLiteTransposeConvParams>();
+      if (auto* transpose_conv_params =
+              op->builtin_options_as_TransposeConvOptions()) {
+        params->padding = parse_padding(transpose_conv_params->padding());
+        params->stride_width = transpose_conv_params->stride_w();
+        params->stride_height = transpose_conv_params->stride_h();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SPARSE_TO_DENSE: {
+      TfLiteSparseToDenseParams* params =
+          MallocPOD<TfLiteSparseToDenseParams>();
+      if (auto* sparse_to_dense_params =
+              op->builtin_options_as_SparseToDenseOptions()) {
+        params->validate_indices = sparse_to_dense_params->validate_indices();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_SHAPE: {
+      auto* params = MallocPOD<TfLiteShapeParams>();
+      if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+        ConvertTensorType(schema_params->out_type(), &params->out_type,
+                          error_reporter);
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_PACK: {
+      TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+      if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+        params->values_count = pack_params->values_count();
+        params->axis = pack_params->axis();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_DELEGATE: {
+      // TODO(ycling): Revisit when supporting saving delegated models.
+      error_reporter->Report("DELEGATE op shouldn't exist in model.");
+      return kTfLiteError;
+    }
+    case BuiltinOperator_FAKE_QUANT: {
+      auto* params = MallocPOD<TfLiteFakeQuantParams>();
+      if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+        params->min = schema_params->min();
+        params->max = schema_params->max();
+        params->num_bits = schema_params->num_bits();
+        params->narrow_range = schema_params->narrow_range();
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_ONE_HOT: {
+      auto* params = MallocPOD<TfLiteOneHotParams>();
+      if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+        params->axis = schema_params->axis();
+      }
+      *builtin_data = static_cast<void*>(params);
+      break;
+    }
+    case BuiltinOperator_UNPACK: {
+      TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+      if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+        params->num = unpack_params->num();
+        params->axis = unpack_params->axis();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
+
+    // Below are the ops with no builtin_data structure.
+    case BuiltinOperator_BATCH_TO_SPACE_ND:
+    // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+    // ok for now, since there is no call implementation either.
+    case BuiltinOperator_CALL:
+    case BuiltinOperator_CONCAT_EMBEDDINGS:
+    case BuiltinOperator_CUSTOM:
+    case BuiltinOperator_DEQUANTIZE:
+    case BuiltinOperator_EMBEDDING_LOOKUP:
+    case BuiltinOperator_EQUAL:
+    case BuiltinOperator_EXP:
+    case BuiltinOperator_EXPAND_DIMS:
+    case BuiltinOperator_FLOOR:
+    case BuiltinOperator_GREATER:
+    case BuiltinOperator_GREATER_EQUAL:
+    case BuiltinOperator_LESS:
+    case BuiltinOperator_LESS_EQUAL:
+    case BuiltinOperator_LOG:
+    case BuiltinOperator_LOGISTIC:
+    case BuiltinOperator_LOG_SOFTMAX:
+    case BuiltinOperator_MAXIMUM:
+    case BuiltinOperator_MINIMUM:
+    case BuiltinOperator_NEG:
+    case BuiltinOperator_NOT_EQUAL:
+    case BuiltinOperator_PAD:
+    case BuiltinOperator_PADV2:
+    case BuiltinOperator_PRELU:
+    case BuiltinOperator_RELU:
+    case BuiltinOperator_RELU6:
+    case BuiltinOperator_RELU_N1_TO_1:
+    case BuiltinOperator_RSQRT:
+    case BuiltinOperator_SELECT:
+    case BuiltinOperator_SIN:
+    case BuiltinOperator_SLICE:
+    case BuiltinOperator_SPACE_TO_BATCH_ND:
+    case BuiltinOperator_SQRT:
+    case BuiltinOperator_TANH:
+    case BuiltinOperator_TILE:
+    case BuiltinOperator_TOPK_V2:
+    case BuiltinOperator_TRANSPOSE:
+    case BuiltinOperator_POW:
+    case BuiltinOperator_LOGICAL_OR:
+    case BuiltinOperator_LOGICAL_AND:
+    case BuiltinOperator_LOGICAL_NOT:
+    case BuiltinOperator_FLOOR_DIV:
+    case BuiltinOperator_SQUARE:
+      break;
+  }
+  return kTfLiteOk;
+}  // NOLINT[readability/fn_size]
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000..4dec6f9
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
+// storage is obtained with `malloc` inside this function, so when the call
+// succeeds and `*builtin_data` is non-null, it is the calling function's
+// responsibility to release that memory by passing it to `free` once the
+// parsed parameters are no longer needed.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+                         ErrorReporter* error_reporter, void** builtin_data);
+
+// Converts the tensor data type used in the flat buffer to the representation
+// used by the runtime.
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+                               ErrorReporter* error_reporter);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
new file mode 100644
index 0000000..b12bdf4
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+namespace {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+  MockErrorReporter() : buffer_size_(0) {}
+  int Report(const char* format, va_list args) override {  // captures instead of printing
+    buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+    return buffer_size_;
+  }
+  char* GetBuffer() { return buffer_; }  // most recently reported message
+  int GetBufferSize() { return buffer_size_; }  // length as reported by vsnprintf
+
+ private:
+  static constexpr int kBufferSize = 256;  // longer messages are truncated
+  char buffer_[kBufferSize];
+  int buffer_size_;
+};
+
+}  // namespace
+
+TEST(FlatbufferConversions, TestParseOpDataConv) {
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;  // serialize a Conv2D operator in memory
+  flatbuffers::Offset<void> conv_options =
+      CreateConv2DOptions(builder, Padding_SAME, 1, 2,
+                          ActivationFunctionType_RELU, 3, 4)
+          .Union();
+  flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
+      builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
+      nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
+  void* output_data = nullptr;
+  EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
+                                   &output_data));
+  EXPECT_NE(nullptr, output_data);
+  TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
+  EXPECT_EQ(kTfLitePaddingSame, params->padding);
+  EXPECT_EQ(1, params->stride_width);
+  EXPECT_EQ(2, params->stride_height);
+  EXPECT_EQ(kTfLiteActRelu, params->activation);
+  EXPECT_EQ(3, params->dilation_width_factor);
+  EXPECT_EQ(4, params->dilation_height_factor);
+  free(output_data);  // ParseOpData allocates the params with malloc
+}
+
+TEST(FlatbufferConversions, TestParseOpDataCustom) {
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<void> null_options;  // CUSTOM ops carry no builtin options
+  flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect(
+      builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr,
+      CustomOptionsFormat_FLEXBUFFERS, nullptr);
+  builder.Finish(custom_offset);
+  void* custom_pointer = builder.GetBufferPointer();
+  const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
+  void* output_data = nullptr;
+  EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
+                                   &output_data));
+  EXPECT_EQ(nullptr, output_data);  // no builtin data is allocated for CUSTOM
+}
+
+TEST(FlatbufferConversions, TestConvertTensorType) {
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+  TfLiteType type;
+  EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter));
+  EXPECT_EQ(kTfLiteFloat32, type);  // schema float32 maps onto runtime float32
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);  // consumes --gtest_* command-line flags
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc
new file mode 100644
index 0000000..55ee924
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+namespace tflite {
+
+TfLiteStatus GetRegistrationFromOpCode(
+    const OperatorCode* opcode, const OpResolver& op_resolver,
+    ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
+  TfLiteStatus status = kTfLiteOk;
+  *registration = nullptr;
+  auto builtin_code = opcode->builtin_code();
+  int version = opcode->version();
+
+  if (builtin_code > BuiltinOperator_MAX ||
+      builtin_code < BuiltinOperator_MIN) {
+    error_reporter->Report(
+        "Op builtin_code out of range: %d. Are you using old TFLite binary "
+        "with newer model?",
+        builtin_code);
+    status = kTfLiteError;
+  } else if (builtin_code != BuiltinOperator_CUSTOM) {
+    *registration = op_resolver.FindOp(builtin_code, version);
+    if (*registration == nullptr) {
+      error_reporter->Report(
+          "Didn't find op for builtin opcode '%s' version '%d'\n",
+          EnumNameBuiltinOperator(builtin_code), version);
+      status = kTfLiteError;
+    }
+  } else if (!opcode->custom_code()) {
+    error_reporter->Report(
+        "Operator with CUSTOM builtin_code has no custom_code.\n");
+    status = kTfLiteError;
+  } else {
+    const char* name = opcode->custom_code()->c_str();
+    *registration = op_resolver.FindOp(name, version);
+    if (*registration == nullptr) {
+      error_reporter->Report(
+          "Didn't find custom op for name '%s' with version %d\n", name,
+          version);
+      status = kTfLiteError;
+    }
+  }
+  return status;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000..5f5e6b2
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism that ops being referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+  // Finds the op registration for a builtin operator by enum code.
+  virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+                                           int version) const = 0;
+  // Finds the op registration of a custom operator by op name.
+  virtual const TfLiteRegistration* FindOp(const char* op,
+                                           int version) const = 0;
+  virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+                                       const OpResolver& op_resolver,
+                                       ErrorReporter* error_reporter,
+                                       const TfLiteRegistration** registration);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000..1674631
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+  // Do nothing.
+  return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+  // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+  return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+  const TfLiteRegistration* FindOp(BuiltinOperator op,
+                                   int version) const override {
+    if (op == BuiltinOperator_CONV_2D) {
+      static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+                                     MockInvoke};
+      return &r;
+    } else {
+      return nullptr;
+    }
+  }
+  const TfLiteRegistration* FindOp(const char* op, int version) const override {
+    if (strcmp(op, "mock_custom") == 0) {
+      static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+                                     MockInvoke};
+      return &r;
+    } else {
+      return nullptr;
+    }
+  }
+};
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+  MockErrorReporter() : buffer_size_(0) {}
+  int Report(const char* format, va_list args) override {
+    buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+    return buffer_size_;
+  }
+  char* GetBuffer() { return buffer_; }
+  int GetBufferSize() { return buffer_size_; }
+
+ private:
+  static constexpr int kBufferSize = 256;
+  char buffer_[kBufferSize];
+  int buffer_size_;
+};
+
+}  // namespace
+
+TEST(OpResolver, TestResolver) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+
+  const TfLiteRegistration* registration =
+      resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+  registration = resolver->FindOp(BuiltinOperator_CAST, 0);
+  EXPECT_EQ(nullptr, registration);
+
+  registration = resolver->FindOp("mock_custom", 0);
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+  registration = resolver->FindOp("nonexistent_custom", 0);
+  EXPECT_EQ(nullptr, registration);
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset =
+      CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+                                                 &registration));
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+  EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset =
+      CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+                                                    reporter, &registration));
+  EXPECT_EQ(nullptr, registration);
+  EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+      builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+                                                 &registration));
+  EXPECT_NE(nullptr, registration);
+  EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+  EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+  EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+  EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+  MockOpResolver mock_resolver;
+  OpResolver* resolver = &mock_resolver;
+  MockErrorReporter mock_reporter;
+  ErrorReporter* reporter = &mock_reporter;
+
+  flatbuffers::FlatBufferBuilder builder;
+  flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+      builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+  builder.Finish(conv_offset);
+  void* conv_pointer = builder.GetBufferPointer();
+  const OperatorCode* conv_code =
+      flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+  const TfLiteRegistration* registration = nullptr;
+  EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+                                                    reporter, &registration));
+  EXPECT_EQ(nullptr, registration);
+  EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index b6b2357..bf5d918 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,6 +16,7 @@
     deps = [
         ":util",
         "//tensorflow/c:c_api_internal",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite:kernel_api",
     ] + select({
         "//tensorflow:android": [
@@ -54,6 +55,7 @@
         ":delegate_data",
         ":kernel",
         ":util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite:kernel_api",
         "//tensorflow/contrib/lite:util",
     ] + select({
@@ -104,6 +106,7 @@
         ":delegate_data",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/testing:util",
         "@com_google_googletest//:gtest",
     ],
@@ -117,6 +120,7 @@
         ":delegate_data",
         ":util",
         "@flatbuffers",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite:kernel_api",
         "//tensorflow/contrib/lite:string",
         "//tensorflow/contrib/lite/kernels:kernel_util",
@@ -170,6 +174,7 @@
     hdrs = ["util.h"],
     deps = [
         "//tensorflow/c:c_api_internal",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite:kernel_api",
     ] + select({
         "//tensorflow:android": [
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index a28329a..aaaa045 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -17,7 +17,7 @@
 
 #include <map>
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6d15ba4..70f3c15 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
 #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index b3a0ffc..def0633 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,7 +16,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/testing/util.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index eb47f46..43ec5d5 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -72,6 +72,26 @@
 
   ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+  ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, NonFloatTypeInference) {
+  AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
+
+  AddTfOp(testing::kAdd, {0, 1}, {2});
+
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2});
+  SetTypedValues<int>(0, {1, 2, 3, 4});
+  SetShape(1, {2, 2});
+  SetTypedValues<int>(1, {4, 3, 2, 1});
+
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
+  ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
+  ASSERT_EQ(GetType(2), kTfLiteInt32);
 }
 
 TEST_F(DelegateTest, MixedGraph) {
@@ -137,6 +157,34 @@
   ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
 }
 
+TEST_F(DelegateTest, MultipleInvokeCalls) {
+  // Call Invoke() multiple times on the same model.
+  AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
+  AddTfLiteMulOp({0, 1}, {2});
+
+  ConfigureDelegate();
+
+  SetShape(0, {2, 2, 1});
+  SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+  SetShape(1, {2, 2, 1});
+  SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
+
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+  ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
+
+  SetShape(0, {2, 2, 1});
+  SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});
+  SetShape(1, {2, 2, 1});
+  SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
+
+  ASSERT_TRUE(Invoke());
+
+  ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+  ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
+}
+
 TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
   // Build a graph, configure the delegate and set inputs.
   {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index f8467c7..274c3c0 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -16,7 +16,7 @@
 
 #include "flatbuffers/flexbuffers.h"  // flatbuffers
 #include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/context_util.h"
 #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
 #include "tensorflow/contrib/lite/delegates/eager/util.h"
@@ -278,7 +278,7 @@
     TfLiteTensor* tensor = &context->tensors[tensor_index];
     TF_LITE_ENSURE_OK(
         context,
-        CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+        CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
     tensor->buffer_handle = tensor_index;
     tensor->data_is_stale = true;
   }
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
index 100672c..2478abc 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
 #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 namespace eager {
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index b8c9e26..8584999 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -25,19 +25,6 @@
 
 bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
 
-void EagerModelTest::SetValues(int tensor_index,
-                               const std::vector<float>& values) {
-  float* v = interpreter_->typed_tensor<float>(tensor_index);
-  for (float f : values) {
-    *v++ = f;
-  }
-}
-
-std::vector<float> EagerModelTest::GetValues(int tensor_index) {
-  TfLiteTensor* o = interpreter_->tensor(tensor_index);
-  return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
-}
-
 void EagerModelTest::SetShape(int tensor_index,
                               const std::vector<int>& values) {
   ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
@@ -54,13 +41,21 @@
   return result;
 }
 
+TfLiteType EagerModelTest::GetType(int tensor_index) {
+  return interpreter_->tensor(tensor_index)->type;
+}
+
 void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
                                 const std::vector<int>& outputs,
-                                const TfLiteType& type,
-                                const std::vector<int>& dims) {
+                                TfLiteType type, const std::vector<int>& dims) {
   interpreter_->AddTensors(num_tensors);
   for (int i = 0; i < num_tensors; ++i) {
     TfLiteQuantizationParams quant;
+    // Suppress explicit output type specification to ensure type inference
+    // works properly.
+    if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
+      type = kTfLiteFloat32;
+    }
     CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
                                                         /*name=*/"",
                                                         /*dims=*/dims, quant),
@@ -101,18 +96,26 @@
     return " attr{ key: '" + key + "' value {" + value + "}}";
   };
 
+  // Crude type attribution, will need fleshing out as more tests are added.
+  // TODO(b/113613439): Use nodedef string utilities to properly handle
+  // all types.
+  string type_attribute = attr("T", "type: DT_FLOAT");
+  if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) {
+    type_attribute = attr("T", "type: DT_INT32");
+  }
+
   if (op == kUnpack) {
-    string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
-                        attr("axis", "i: 0");
+    string attributes =
+        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
     AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
   } else if (op == kIdentity) {
-    string attributes = attr("T", "type: DT_FLOAT");
+    string attributes = type_attribute;
     AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
   } else if (op == kAdd) {
-    string attributes = attr("T", "type: DT_FLOAT");
+    string attributes = type_attribute;
     AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
   } else if (op == kMul) {
-    string attributes = attr("T", "type: DT_FLOAT");
+    string attributes = type_attribute;
     AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
   } else if (op == kNonExistent) {
     AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
index 0eab9e1..816db41 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -44,11 +44,30 @@
 
   bool Invoke();
 
+  // Sets the (typed) tensor's values at the given index.
+  template <typename T>
+  void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+    memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+           values.size() * sizeof(T));
+  }
+
+  // Returns the (typed) tensor's values at the given index.
+  template <typename T>
+  std::vector<T> GetTypedValues(int tensor_index) {
+    const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+    const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+    return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+  }
+
   // Sets the tensor's values at the given index.
-  void SetValues(int tensor_index, const std::vector<float>& values);
+  void SetValues(int tensor_index, const std::vector<float>& values) {
+    SetTypedValues<float>(tensor_index, values);
+  }
 
   // Returns the tensor's values at the given index.
-  std::vector<float> GetValues(int tensor_index);
+  std::vector<float> GetValues(int tensor_index) {
+    return GetTypedValues<float>(tensor_index);
+  }
 
   // Sets the tensor's shape at the given index.
   void SetShape(int tensor_index, const std::vector<int>& values);
@@ -56,13 +75,16 @@
   // Returns the tensor's shape at the given index.
   std::vector<int> GetShape(int tensor_index);
 
+  // Returns the tensor's type at the given index.
+  TfLiteType GetType(int tensor_index);
+
   const TestErrorReporter& error_reporter() const { return error_reporter_; }
 
   // Adds `num_tensor` tensors to the model. `inputs` contains the indices of
   // the input tensors and `outputs` contains the indices of the output
   // tensors. All tensors are set to have `type` and `dims`.
   void AddTensors(int num_tensors, const std::vector<int>& inputs,
-                  const std::vector<int>& outputs, const TfLiteType& type,
+                  const std::vector<int>& outputs, TfLiteType type,
                   const std::vector<int>& dims);
 
   // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 4426c65..051246b 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -26,8 +26,17 @@
   return kTfLiteOk;
 }
 
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
-                       TfLiteTensor* tensor) {
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+                              const tensorflow::Tensor& src,
+                              TfLiteTensor* tensor) {
+  tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
+  if (tensor->type == kTfLiteNoType) {
+    context->ReportError(context,
+                         "TF Lite does not support TensorFlow data type: %s",
+                         DataTypeString(src.dtype()).c_str());
+    return kTfLiteError;
+  }
+
   int num_dims = src.dims();
   TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
   for (int j = 0; j < num_dims; ++j) {
@@ -68,5 +77,28 @@
   }
 }
 
+TfLiteType GetTensorFlowLiteType(TF_DataType type) {
+  switch (type) {
+    case TF_FLOAT:
+      return kTfLiteFloat32;
+    case TF_INT16:
+      return kTfLiteInt16;
+    case TF_INT32:
+      return kTfLiteInt32;
+    case TF_UINT8:
+      return kTfLiteUInt8;
+    case TF_INT64:
+      return kTfLiteInt64;
+    case TF_COMPLEX64:
+      return kTfLiteComplex64;
+    case TF_STRING:
+      return kTfLiteString;
+    case TF_BOOL:
+      return kTfLiteBool;
+    default:
+      return kTfLiteNoType;
+  }
+}
+
 }  // namespace eager
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index a9407be..930cb99 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -16,7 +16,7 @@
 #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
 
 #include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 
@@ -28,14 +28,19 @@
 TfLiteStatus ConvertStatus(TfLiteContext* context,
                            const tensorflow::Status& status);
 
-// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an
-// error and returns kTfLiteError if the shape can't be converted.
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
-                       TfLiteTensor* tensor);
+// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite
+// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't
+// be converted.
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+                              const tensorflow::Tensor& src,
+                              TfLiteTensor* tensor);
 
 // Returns the TF C API Data type that corresponds to the given TfLiteType.
 TF_DataType GetTensorFlowDataType(TfLiteType type);
 
+// Returns the TfLiteType that corresponds to the given TF C API Data type.
+TfLiteType GetTensorFlowLiteType(TF_DataType);
+
 }  // namespace eager
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 53378a1..aebc911 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -26,6 +26,7 @@
 namespace {
 
 using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
 using tensorflow::Tensor;
 using ::testing::ElementsAre;
 
@@ -71,27 +72,41 @@
   EXPECT_TRUE(context.error.empty());
 }
 
-TEST(UtilTest, CopyShape) {
+TEST(UtilTest, CopyShapeAndType) {
   TestContext context;
   context.ReportError = ReportError;
   context.ResizeTensor = ResizeTensor;
 
   TfLiteTensor dst;
 
-  EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
+  EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk);
   EXPECT_THAT(context.new_size, ElementsAre(0));
+  EXPECT_EQ(dst.type, kTfLiteFloat32);
 
-  EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk);
+  EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst),
+            kTfLiteOk);
   EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+  EXPECT_EQ(dst.type, kTfLiteFloat32);
 
-  EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
+  EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst),
+            kTfLiteOk);
+  EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+  EXPECT_EQ(dst.type, kTfLiteInt32);
+
+  EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
             kTfLiteError);
   EXPECT_EQ(context.error,
             "Dimension value in TensorFlow shape is larger than supported by "
             "TF Lite");
+
+  EXPECT_EQ(
+      CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst),
+      kTfLiteError);
+  EXPECT_EQ(context.error,
+            "TF Lite does not support TensorFlow data type: half");
 }
 
-TEST(UtilTest, TypeConversions) {
+TEST(UtilTest, TypeConversionsFromTFLite) {
   EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
   EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
   EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
@@ -103,6 +118,19 @@
   EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
 }
 
+TEST(UtilTest, TypeConversionsFromTensorFlow) {
+  EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
+  EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
+  EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
+  EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
+  EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
+  EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
+  EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
+  EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
+  EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
+  EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
+}
+
 }  // namespace
 }  // namespace eager
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
index 954955f..4e7b294 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/BUILD
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -13,6 +13,7 @@
     deps = [
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:kernel_api",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:kernel_util",
         "//tensorflow/contrib/lite/nnapi:nnapi_lib",
     ],
@@ -29,6 +30,7 @@
     deps = [
         ":nnapi_delegate",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index 980a1cb..c6587b3 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/contrib/lite/allocation.h"
 #include "tensorflow/contrib/lite/builtin_op_data.h"
 #include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/context_util.h"
 #include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -1115,6 +1115,14 @@
     CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
                           nn_model_.get(), inputs.size(), inputs.data(),
                           outputs.size(), outputs.data()));
+
+    // Set relaxed computation mode for fp32 if possible.
+    if (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) {
+      CHECK_NN(context,
+               ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+                   nn_model_.get(), context->allow_fp32_relax_to_fp16));
+    }
+
     // Finalize the model
     CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get()));
 
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 44cca2f..4852b76 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
 #define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 4b01aef..9626c54 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -40,13 +40,15 @@
  public:
   FloatAddOpModel(const TensorData& input1, const TensorData& input2,
                   const TensorData& output,
-                  ActivationFunctionType activation_type) {
+                  ActivationFunctionType activation_type,
+                  bool allow_fp32_relax_to_fp16 = false) {
     input1_ = AddInput(input1);
     input2_ = AddInput(input2);
     output_ = AddOutput(output);
     SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
                  CreateAddOptions(builder_, activation_type).Union());
-    BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+    BuildInterpreter({GetShape(input1_), GetShape(input2_)},
+                     allow_fp32_relax_to_fp16);
   }
 
   int input1() { return input1_; }
@@ -71,6 +73,19 @@
   EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
 }
 
+// Do a test with the NN API using no activation.
+// The test allows computing FP32 with FP16 precision. In this particular case,
+// calculating in FP32 or FP16 should produce the same results.
+TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
+  FloatAddOpModel m(
+      {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
+      {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
+  m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
+  m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 1.0, 4.0, 6.0}));
+}
+
 // Do a test with the NN api with relu.
 TEST(NNAPIDelegate, AddWithRelu) {
   FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index 3c5f805..5c20eed 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,43 +12,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// Compatibility shim for moved header location.
 #ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
 #define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
 
-#include <cstdarg>
-#include "tensorflow/contrib/lite/context.h"
-
-namespace tflite {
-
-// A functor that reports error to supporting system. Invoked similar to
-// printf.
-//
-// Usage:
-//  ErrorReporter foo;
-//  foo.Report("test %d", 5);
-// or
-//  va_list args;
-//  foo.Report("test %d", args); // where args is va_list
-//
-// Subclass ErrorReporter to provide another reporting destination.
-// For example, if you have a GUI program, you might redirect to a buffer
-// that drives a GUI error log box.
-class ErrorReporter {
- public:
-  virtual ~ErrorReporter();
-  virtual int Report(const char* format, va_list args) = 0;
-  int Report(const char* format, ...);
-  int ReportError(void*, const char* format, ...);
-};
-
-// An error reporter that simplify writes the message to stderr.
-struct StderrReporter : public ErrorReporter {
-  int Report(const char* format, va_list args) override;
-};
-
-// Return the default error reporter (output to stderr).
-ErrorReporter* DefaultErrorReporter();
-
-}  // namespace tflite
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
 
 #endif  // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 8fc07e8..ea4a543 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -78,6 +78,7 @@
     data = ["//tensorflow/contrib/lite:testdata/add.bin"],
     deps = [
         ":c_api",
+        "//tensorflow/contrib/lite:context",
         "//tensorflow/contrib/lite:kernel_api",
         "//tensorflow/contrib/lite/testing:util",
         "@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index a4ab0e8..c589cf7 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 #include "tensorflow/contrib/lite/experimental/c/c_api.h"
 
+#include <memory>
+
 #include "tensorflow/contrib/lite/context.h"
 #include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
@@ -29,12 +31,14 @@
 TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
   auto model = tflite::FlatBufferModel::BuildFromBuffer(
       static_cast<const char*>(model_data), model_size);
-  return model ? new TFL_Model{std::move(model)} : nullptr;
+  std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+  return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
 }
 
 TFL_Model* TFL_NewModelFromFile(const char* model_path) {
   auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
-  return model ? new TFL_Model{std::move(model)} : nullptr;
+  std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+  return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
 }
 
 void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@
     }
   }
 
-  return new TFL_Interpreter{std::move(interpreter)};
+  return new TFL_Interpreter{model->impl, std::move(interpreter)};
 }
 
 void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@
   return static_cast<void*>(tensor->data.raw);
 }
 
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
 TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
                                     size_t input_data_size) {
   if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349..b429e76 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@
 // failure.
 //
 // * `model` must be a valid model instance. The caller retains ownership of the
-//   object, and can destroy it immediately after creating the interpreter.
+//   object, and can destroy it immediately after creating the interpreter; the
+//   interpreter will maintain its own reference to the underlying model data.
 // * `optional_options` may be null. The caller retains ownership of the object,
 //   and can safely destroy it immediately after creating the interpreter.
 //
@@ -145,6 +146,11 @@
 
 // Returns the tensor associated with the output index.
 // REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor)
+//
+// NOTE: The shape and underlying data buffer for output tensors may not
+// be available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
 TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
     const TFL_Interpreter* interpreter, int32_t output_index);
 
@@ -172,12 +178,15 @@
 
 // Returns a pointer to the underlying data buffer.
 //
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
 // if the Tensor has just been created or resized and `TFL_AllocateTensors()`
 // has yet to be called, or if the output tensor is dynamically sized and the
 // interpreter hasn't been invoked.
 TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
 
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
 // Copies from the provided input buffer into the tensor's buffer.
 // REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
 TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a..60c2e4e 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@
 // not be depended on.
 
 struct TFL_Model {
-  std::unique_ptr<tflite::FlatBufferModel> impl;
+  // Sharing is safe as FlatBufferModel is const.
+  std::shared_ptr<const tflite::FlatBufferModel> impl;
 };
 
 struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@
 };
 
 struct TFL_Interpreter {
+  // Taking a reference to the (const) model data avoids lifetime-related issues
+  // and complexity with the TFL_Model's existence.
+  std::shared_ptr<const tflite::FlatBufferModel> model;
   std::unique_ptr<tflite::Interpreter> impl;
 };
 
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae..649dac8 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@
   EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
   EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
   EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+  EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+  EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
 
   std::array<float, 2> input = {1.f, 3.f};
   ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@
   EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
   EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
   EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+  EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+  EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
 
   std::array<float, 2> output;
   ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4e..4786cc6 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@
         "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:string_util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:builtin_ops",
         "//tensorflow/contrib/lite/kernels:gemm_support",
         "//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@
         "//tensorflow/contrib/lite/kernels/internal:optimized",
         "//tensorflow/contrib/lite/kernels/internal:optimized_base",
         "//tensorflow/contrib/lite/kernels/internal:quantization_util",
-        "//tensorflow/contrib/lite/kernels/internal:reference",
         "//tensorflow/contrib/lite/kernels/internal:reference_base",
+        "//tensorflow/contrib/lite/kernels/internal:tensor",
         "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "@flatbuffers",
     ],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
index c658e43..7c50992 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -257,6 +257,16 @@
   } else {
     max_coeff = raw_input.maxCoeff();
   }
+
+  // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+  float logsumexp = 0.0;
+  for (int j = 0; j < raw_input.size(); ++j) {
+    logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+  }
+  logsumexp = Eigen::numext::log(logsumexp);
+  // Final normalization offset to get correct log probabilities.
+  float norm_offset = max_coeff + logsumexp;
+
   const float label_selection_input_min =
       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
                                      : -std::numeric_limits<float>::infinity();
@@ -288,10 +298,10 @@
                       beam_scorer_->GetStateExpansionScore(b->state, previous));
       }
       // Plabel(l=abc @ t=6) *= P(c @ 6)
-      b->newp.label += raw_input(b->label) - max_coeff;
+      b->newp.label += raw_input(b->label) - norm_offset;
     }
     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
-    b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+    b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
 
@@ -326,6 +336,8 @@
       const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
       // Perform label selection: if input for this label looks very
       // unpromising, never evaluate it with a scorer.
+      // We may compare logits instead of log probabilities,
+      // since the difference is the same in both cases.
       if (logit < label_selection_input_min) {
         continue;
       }
@@ -339,7 +351,7 @@
         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
         float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
-        c.newp.label = logit - max_coeff +
+        c.newp.label = logit - norm_offset +
                        beam_scorer_->GetStateExpansionScore(c.state, previous);
         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
         c.newp.total = c.newp.label;
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 121997d..8442c4d 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 #include <vector>
 #include "flatbuffers/flexbuffers.h"  // flatbuffers
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index 3245830..aa42b49 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -117,7 +117,7 @@
   EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
   // Check log probabilities output.
   EXPECT_THAT(m.GetLogProbabilitiesOutput(),
-              ElementsAreArray(ArrayFloatNear({0.32134813})));
+              ElementsAreArray(ArrayFloatNear({-0.357094})));
 }
 
 TEST(CTCBeamSearchTest, MultiBatchTest) {
@@ -148,9 +148,8 @@
   EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
   EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
   // Check log probabilities output.
-  EXPECT_THAT(
-      m.GetLogProbabilitiesOutput(),
-      ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+  EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+              ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
 }
 
 TEST(CTCBeamSearchTest, MultiPathsTest) {
@@ -188,8 +187,8 @@
   EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
   // Check log probabilities output.
   EXPECT_THAT(m.GetLogProbabilitiesOutput(),
-              ElementsAreArray(ArrayFloatNear(
-                  {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+              ElementsAreArray(
+                  ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
 }
 
 TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
@@ -223,7 +222,7 @@
   EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
   // Check log probabilities output.
   EXPECT_THAT(m.GetLogProbabilitiesOutput(),
-              ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+              ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
 }
 
 }  // namespace
diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD
new file mode 100644
index 0000000..82d39c0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = [
+    "//visibility:public",
+])
+
+licenses(["notice"])  # Apache 2.0
+
+cc_binary(
+    name = "option_writer_generator",
+    srcs = ["option_writer_generator.cc"],
+    deps = [
+        "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+        "@flatbuffers",
+    ],
+)
+
+cc_library(
+    name = "writer_lib",
+    srcs = [
+        "enum_mapping.h",
+        "writer_lib.cc",
+    ],
+    hdrs = [
+        "writer_lib.h",
+    ],
+    data = [
+        ":option_writer_gen",
+    ],
+    textual_hdrs = ["option_writer_generated.h"],
+    deps = [
+        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite:schema_fbs_version",
+        "//tensorflow/contrib/lite/kernels:builtin_ops",
+        "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+    ],
+)
+
+cc_binary(
+    name = "writer",
+    srcs = ["writer.cc"],
+    deps = [
+        ":writer_lib",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:builtin_ops",
+    ],
+)
+
+cc_test(
+    name = "writer_lib_test",
+    size = "small",
+    srcs = ["writer_lib_test.cc"],
+    deps = [
+        ":writer_lib",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:builtin_ops",
+        "//tensorflow/contrib/lite/testing:util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+genrule(
+    name = "option_writer_gen",
+    outs = ["option_writer_generated.h"],
+    cmd = "$(location :option_writer_generator) $(@)",
+    tools = [":option_writer_generator"],
+)
diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
new file mode 100644
index 0000000..8bc464f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+// TODO(aselle): Ideally extract this from the schema.
+
+namespace tflite {
+
+inline ActivationFunctionType TfLiteActivationToSchemaActivation(
+    TfLiteFusedActivation act) {
+  switch (act) {
+    case kTfLiteActNone:
+      return ActivationFunctionType_NONE;
+    case kTfLiteActRelu:
+      return ActivationFunctionType_RELU;
+    case kTfLiteActRelu1:
+      return ActivationFunctionType_RELU_N1_TO_1;
+    case kTfLiteActRelu6:
+      return ActivationFunctionType_RELU6;
+    case kTfLiteActTanh:
+      return ActivationFunctionType_TANH;
+    case kTfLiteActSignBit:
+      return ActivationFunctionType_SIGN_BIT;
+    case kTfLiteActSigmoid:
+      return ActivationFunctionType_NONE;  // TODO(aselle): Add to schema
+  }
+  return ActivationFunctionType_NONE;
+}
+
+inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
+  switch (padding) {
+    case kTfLitePaddingUnknown:
+      return Padding_SAME;  // TODO(aselle): Consider an error.
+    case kTfLitePaddingSame:
+      return Padding_SAME;
+    case kTfLitePaddingValid:
+      return Padding_VALID;
+  }
+  return Padding_SAME;  // TODO(aselle): Consider an error.
+}
+
+inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
+  switch (type) {
+    // case kTfLiteNoType: return TensorType_NONE;
+    case kTfLiteNoType:
+      return TensorType_FLOAT32;  // TODO(aselle): Consider an error.
+    case kTfLiteFloat32:
+      return TensorType_FLOAT32;
+    case kTfLiteInt32:
+      return TensorType_INT32;
+    case kTfLiteUInt8:
+      return TensorType_UINT8;
+    case kTfLiteInt64:
+      return TensorType_INT64;
+    case kTfLiteString:
+      return TensorType_STRING;
+    case kTfLiteBool:
+      return TensorType_BOOL;
+    case kTfLiteInt16:
+      return TensorType_INT16;
+    case kTfLiteComplex64:
+      return TensorType_COMPLEX64;
+  }
+  // TODO(aselle): consider an error
+}
+
+inline FullyConnectedOptionsWeightsFormat
+FullyConnectedOptionsWeightsFormatToSchema(
+    TfLiteFullyConnectedWeightsFormat format) {
+  switch (format) {
+    case kTfLiteFullyConnectedWeightsFormatDefault:
+      return FullyConnectedOptionsWeightsFormat_DEFAULT;
+    case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
+      return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+  }
+}
+
+inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
+  switch (type) {
+    case kTfLiteLSTMFullKernel:
+      return LSTMKernelType_FULL;
+    case kTfLiteLSTMBasicKernel:
+      return LSTMKernelType_BASIC;
+  }
+}
+
+inline LSHProjectionType LSHProjectionTypeToSchema(
+    TfLiteLSHProjectionType type) {
+  switch (type) {
+    case kTfLiteLshProjectionUnknown:
+      return LSHProjectionType_UNKNOWN;
+    case kTfLiteLshProjectionSparse:
+      return LSHProjectionType_SPARSE;
+    case kTfLiteLshProjectionDense:
+      return LSHProjectionType_DENSE;
+  }
+}
+
+}  // namespace tflite
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000..e6d5a77
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include "flatbuffers/minireflect.h"  // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
// This is generated by grepping
//  cat  third_party/tensorflow/contrib/lite/builtin_op_data.h
//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
// The names below are used to validate guessed "TfLite<Option>Params" struct
// names in BuildOpToOptionMap(); the array is terminated by a nullptr
// sentinel so callers can iterate without a length.
static const char* param_structs[] = {"TfLiteConvParams",
                                      "TfLitePoolParams",
                                      "TfLiteDepthwiseConvParams",
                                      "TfLiteSVDFParams",
                                      "TfLiteRNNParams",
                                      "TfLiteSequenceRNNParams",
                                      "TfLiteFullyConnectedParams",
                                      "TfLiteLSHProjectionParams",
                                      "TfLiteSoftmaxParams",
                                      "TfLiteConcatenationParams",
                                      "TfLiteAddParams",
                                      "TfLiteSpaceToBatchNDParams",
                                      "TfLiteBatchToSpaceNDParams",
                                      "TfLiteMulParams",
                                      "TfLiteSubParams",
                                      "TfLiteDivParams",
                                      "TfLiteL2NormParams",
                                      "TfLiteLocalResponseNormParams",
                                      "TfLiteLSTMParams",
                                      "TfLiteResizeBilinearParams",
                                      "TfLitePadParams",
                                      "TfLitePadV2Params",
                                      "TfLiteReshapeParams",
                                      "TfLiteSkipGramParams",
                                      "TfLiteSpaceToDepthParams",
                                      "TfLiteCastParams",
                                      "TfLiteEmbeddingLookupSparseParams",
                                      "TfLiteGatherParams",
                                      "TfLiteTransposeParams",
                                      "TfLiteReducerParams",
                                      "TfLiteSplitParams",
                                      "TfLiteSqueezeParams",
                                      "TfLiteStridedSliceParams",
                                      "TfLiteArgMaxParams",
                                      "TfLiteArgMinParams",
                                      "TfLiteTransposeConvParams",
                                      "TfLiteSparseToDenseParams",
                                      "TfLiteShapeParams",
                                      "TfLiteFakeQuantParams",
                                      "TfLitePackParams",
                                      "TfLiteOneHotParams",
                                      nullptr};
+}  // namespace
+
// Get rid of all underscores and make everything lower case to make name
// matching work for stuff like 3D vs 3d or RNN vs Rnn.
//
// (The original version carried a `first` flag that never affected the
// output: both branches pushed the same lowercased character, so the flag
// was dead code and has been removed.)
std::string ToCollapsed(const std::string& in) {
  std::string out;
  out.reserve(in.size());
  for (char c : in) {
    if (c != '_') out.push_back(tolower(c));
  }
  return out;
}
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+  OpOptionData() {
+    BuildOpList();
+    BuildOptionToTypeFunctionMap();
+    BuildOpToOptionMap();
+  }
+
+  // A list of builtin operations
+  const std::vector<std::string>& ops() const { return ops_; }
+  // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions')
+  const std::unordered_map<std::string, std::string>& op_to_option() {
+    return op_to_option_;
+  }
+  // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions'
+  const std::unordered_map<std::string, std::string>& option_to_struct() {
+    return option_to_struct_;
+  }
+  // Maps from option to a flatbuffer type function that describes that option.
+  const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+  option_to_type_function() {
+    return option_to_type_function_;
+  }
+
+ private:
+  void BuildOpList() {
+    for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+         ++curr) {
+      if (strlen(*curr) != 0) ops_.push_back(*curr);
+    }
+  }
+
+  void BuildOptionToTypeFunctionMap() {
+    auto d = tflite::BuiltinOptionsTypeTable();
+    for (int i = 0; i < d->num_elems; i++) {
+      flatbuffers::TypeCode code = d->type_codes[i];
+      if (code.sequence_ref != -1) {
+        option_to_type_function_.insert(
+            std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+      }
+    }
+  }
+
+  void BuildOpToOptionMap() {
+    // Manually specified mappings between ops and options
+    op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+    op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+    op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+    op_to_option_["UNPACK"] = "";
+    op_to_option_["SUM"] = "ReducerOptions";
+    op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+    op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+    op_to_option_["MEAN"] = "ReducerOptions";
+    op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+    op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+    op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+    op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+    op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+    op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+    op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+    // Manually specified mappings between ops and options (none)
+    op_to_option_["EMBEDDING_LOOKUP"] =
+        "";  // TODO(aselle): maybe something else.
+    op_to_option_["FLOOR"] = "";
+    op_to_option_["HASHTABLE_LOOKUP"] =
+        "";  // TODO(aselle): maybe something else.
+    op_to_option_["LOGISTIC"] = "";
+    op_to_option_["RELU"] = "";
+    op_to_option_["RELU_N1_TO_1"] = "";
+    op_to_option_["RELU6"] = "";
+    op_to_option_["TANH"] = "";
+    op_to_option_["CUSTOM"] = "";    // TODO(aselle): maybe something else.
+    op_to_option_["DELEGATE"] = "";  // TODO(aselle): maybe something else.
+    op_to_option_["PRELU"] = "";
+    op_to_option_["MAXIMUM"] = "";  // TODO(aselle): MaximumMinimumOptions
+    op_to_option_["MINIMUM"] = "";  // TODO(aselle): MaximumMinimumOptions
+    op_to_option_["SIN"] = "";
+    op_to_option_["LOG"] = "";
+    op_to_option_["SQRT"] = "";
+    op_to_option_["RSQRT"] = "";
+
+    // TODO(aselle): These are undesirable hacks. Consider changing C structs
+    option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+    option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+    option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+    option_to_struct_["LocalResponseNormalizationOptions"] =
+        "TfLiteLocalResponseNormParams";
+    // Now for every op, try to find an option.
+    bool fatal = false;
+    for (auto op_name : ops_) {
+      bool found_option = false;
+      auto d = tflite::BuiltinOptionsTypeTable();
+      std::string collapsed_option_name_guess =
+          ToCollapsed(op_name) + "options";
+      // O(n^2) but not that big of n.
+      for (int i = 0; i < d->num_elems; i++) {
+        std::string option_name = d->names[i];
+        std::string collapsed_option_name = ToCollapsed(option_name);
+        if (collapsed_option_name_guess == collapsed_option_name) {
+          op_to_option_.insert(std::make_pair(op_name, option_name));
+          found_option = true;
+          break;
+        }
+      }
+      auto it = op_to_option_.find(op_name);
+      if (it == op_to_option_.end()) {
+        std::cerr << "Didn't find option for  " << op_name << std::endl;
+        fatal = true;
+      } else if (!it->second.empty()) {
+        std::string option_name = it->second;
+
+        if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+          bool param_struct_found = false;
+          std::string params_guess = std::string("TfLite") + option_name;
+          size_t start = params_guess.find("Options");
+          size_t len = strlen("Options");
+          params_guess.replace(start, len, "Params");
+          for (auto* param = param_structs; *param != nullptr; param++) {
+            if (*param == params_guess) {
+              param_struct_found = true;
+              break;
+            }
+          }
+          if (!param_struct_found) {
+            std::cerr << "Failed to get param struct for option " << option_name
+                      << std::endl;
+            fatal = true;
+          } else {
+            option_to_struct_.insert(std::make_pair(option_name, params_guess));
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  std::vector<std::string> ops_;
+  std::unordered_map<std::string, std::string> op_to_option_;
+  std::unordered_map<std::string, std::string> option_to_struct_;
+  std::unordered_map<std::string, flatbuffers::TypeFunction>
+      option_to_type_function_;
+};
+
// Emits (to `fp`) one switch-case fragment that converts builtin op
// `op_name`'s C param struct `struct_name` into its flatbuffer option table
// `option_name`.  `options` is the reflection TypeTable enumerating the
// option's fields; `option_type` is the BuiltinOptions enum spelling the
// generated case returns alongside the built union offset.
void GenerateImportForOp(FILE* fp, const std::string& op_name,
                         const std::string& option_name,
                         const std::string& option_type,
                         const flatbuffers::TypeTable* options,
                         const std::string& struct_name) {
  // Skip tricky ones for now
  if (struct_name == "TfLiteResizeBilinearParams") return;
  if (struct_name == "TfLiteSqueezeParams") return;
  if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
  if (struct_name == "TfLiteReshapeParams") return;

  fprintf(fp, "  case BuiltinOperator_%s:  {\n", op_name.c_str());
  fprintf(fp,
          "    const auto* params = reinterpret_cast<const "
          "%s*>(builtin_op_data);\n",
          struct_name.c_str());

  // Emit one `auto valN = ...;` extraction per schema field.
  for (size_t i = 0; i < options->num_elems; i++) {
    std::string elem_name = options->names[i];
    // TODO(aselle): Irregular naming in builtins
    if (elem_name == "fused_activation_function")
      elem_name = "activation";
    else if (elem_name == "stride_w")
      elem_name = "stride_width";
    else if (elem_name == "stride_h")
      elem_name = "stride_height";
    else if (elem_name == "dilation_h_factor")
      elem_name = "dilation_height_factor";
    else if (elem_name == "dilation_w_factor")
      elem_name = "dilation_width_factor";
    else if (elem_name == "new_shape")
      elem_name = "shape";

    // Fields whose schema type is itself an enum need one of the
    // TfLite->schema conversion helpers from enum_mapping.h; scalar fields
    // get an empty mapper, which generates a plain `(params->x)` read.
    flatbuffers::TypeCode code = options->type_codes[i];
    auto contained_type = code.sequence_ref != -1
                              ? options->type_refs[code.sequence_ref]
                              : nullptr;
    std::string mapper = "";
    if (contained_type == TensorTypeTypeTable) {
      mapper = "TfLiteTypeToSchemaType";
    } else if (contained_type == ActivationFunctionTypeTypeTable) {
      mapper = "TfLiteActivationToSchemaActivation";
    } else if (contained_type == PaddingTypeTable) {
      mapper = "TfLitePaddingToSchemaPadding";
    } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
      mapper = "FullyConnectedOptionsWeightsFormatToSchema";
    } else if (contained_type == LSTMKernelTypeTypeTable) {
      mapper = "LSTMKernelTypeToSchema";
    } else if (contained_type == LSHProjectionTypeTypeTable) {
      mapper = "LSHProjectionTypeToSchema";
    }

    fprintf(fp,
            "    auto val%zu = "
            "%s(params->%s);\n",
            i, mapper.c_str(), elem_name.c_str());
  }
  // Build the option table from the extracted values and return it paired
  // with its BuiltinOptions discriminator.
  fprintf(fp, "    auto union_type = Create%s(*fbb", option_name.c_str());
  for (size_t i = 0; i < options->num_elems; i++) {
    fprintf(fp, ", val%zu", i);
  }
  fprintf(fp, ").Union();\n");
  fprintf(fp, "    return std::make_pair(%s, union_type);\n",
          option_type.c_str());
  fprintf(fp, "  }\n  break;\n");
}
+
// Writes to `fp` the full switch body mapping every builtin operator to a
// (BuiltinOptions enum, union offset) pair.  Ops that carry no option struct
// share one grouped run of case labels returning BuiltinOptions_NONE.
void GenerateImport(OpOptionData* option, FILE* fp) {
  std::unordered_set<std::string> ignores;
  ignores.insert("CONCAT_EMBEDDINGS");
  ignores.insert("CALL");

  // Allow any op that doesn't have an options struct to be blocked
  // together
  for (const auto& op_name : option->ops()) {
    // NOTE(review): `find` is dereferenced without an end() check — assumes
    // OpOptionData's construction guarantees every op has an entry; confirm.
    auto option_it = option->op_to_option().find(op_name);
    if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
      continue;
    fprintf(fp, "  case BuiltinOperator_%s:\n", op_name.c_str());
  }
  fprintf(fp,
          "    return std::make_pair(BuiltinOptions_NONE, "
          "flatbuffers::Offset<void>());\n    break;\n");

  // Iterate over each op
  for (const auto& op_name : option->ops()) {
    if (ignores.find(op_name) != ignores.end()) continue;
    // Get to the option and struct names, continuing if not found.
    auto option_it = option->op_to_option().find(op_name);
    if (option_it->second.empty()) continue;
    std::string option_name = option_it->second;
    std::string option_type = "BuiltinOptions_" + option_name;
    auto option_func_it = option->option_to_type_function().find(option_name);
    if (option_func_it == option->option_to_type_function().end()) continue;
    auto struct_name_it = option->option_to_struct().find(option_name);
    if (struct_name_it == option->option_to_struct().end()) {
      // If no C struct, then it better have no arguments.
      auto type_info = option_func_it->second();
      if (type_info->num_elems != 0) {
        // We have non-zero arguments in the schema, this means there
        // should be a struct.
        fprintf(stderr,
                "Op %s uses option struct %s which has no builtin struct\n",
                op_name.c_str(), option_name.c_str());
        exit(1);
      }
      // Option table with no fields: emit a trivial default-constructed case.
      fprintf(fp, "  case BuiltinOperator_%s:\n", op_name.c_str());
      fprintf(fp, "    return std::make_pair(%s, Create%s(*fbb).Union());",
              option_type.c_str(), option_name.c_str());
    } else {
      // If C struct, then we need to assign all properties
      auto struct_name = struct_name_it->second;
      GenerateImportForOp(fp, op_name, option_name, option_type,
                          option_func_it->second(), struct_name);
    }
  }
  // TODO(aselle): Handle unhandled cases more gracefully.
  fprintf(fp,
          "default:    return std::make_pair(BuiltinOptions_NONE, "
          "flatbuffers::Offset<void>());\n    break;\n");
}
+
+}  // namespace tflite
+
+int main(int argc, char* argv[]) {
+  tflite::OpOptionData option;
+  if (argc != 2) {
+    fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+    return 1;
+  }
+  FILE* fp = fopen(argv[1], "w");
+  tflite::GenerateImport(&option, fp);
+  fclose(fp);
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc
new file mode 100644
index 0000000..20ede21
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Just does a read/write loop of tflite file format using the interpreter as
+// an intermediate.
+//
+// Usage:
+//   writer <input tflite> <output tflite>
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+int main(int argc, char* argv[]) {
+  if (argc != 3) {
+    fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
+    return 1;
+  }
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::BuildFromFile(argv[1]);
+  std::unique_ptr<tflite::Interpreter> interpreter;
+  tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
+  tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
+  tflite::InterpreterWriter writer(interpreter.get());
+  writer.Write(argv[2]);
+
+  return 0;
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
new file mode 100644
index 0000000..555a9cc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -0,0 +1,287 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <cstdlib>
+#include <cstring>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+template <class T>
+using Offset = flatbuffers::Offset<T>;
+template <class T>
+using Vector = flatbuffers::Vector<T>;
+using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
+
// Builds the flatbuffer option union for builtin op `op` from its C param
// struct `builtin_op_data`.  The case labels of this switch are produced by
// option_writer_generator into option_writer_generated.h; the trailing
// return is a fallback should the generated cases not cover `op`.
std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
    FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
  switch (op) {
#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
  }
  return std::make_pair(BuiltinOptions_NONE, Offset<void>());
}
+
+template <class T_OUTPUT, class T_INPUT>
+Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
+                                                         const T_INPUT& v) {
+  std::vector<T_OUTPUT> inputs(v.begin(), v.end());
+  return fbb->template CreateVector<T_OUTPUT>(inputs);
+}
+
// Serializes every node in the execution plan into a flatbuffer Operator
// vector.  Two passes over the plan: the first assigns each node an index
// into the opcode table, the second emits the operators with remapped
// tensor indices and builtin/custom options.
Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
    FlatBufferBuilder* fbb) {
  std::vector<Offset<Operator>> operators;

  std::vector<int> operator_to_opcode;
  // TODO(aselle): Augment this once we put execution plan in schema.
  operator_to_opcode.resize(interpreter_->nodes_size(), -1);
  for (int op_index : interpreter_->execution_plan()) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    const TfLiteRegistration* registration = &node_and_registration->second;
    // A null custom_name marks a builtin op.
    if (!registration->custom_name) {
      operator_to_opcode[op_index] =
          GetOpCodeForBuiltin(registration->builtin_code);
    } else {
      operator_to_opcode[op_index] =
          GetOpCodeForCustom(registration->custom_name);
    }
  }
  // second pass serialize operators
  for (int op_index : interpreter_->execution_plan()) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    const TfLiteNode& node = node_and_registration->first;
    const TfLiteRegistration& registration = node_and_registration->second;
    Offset<void> builtin_options;
    BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
    // Custom data
    // TODO(aselle): Custom options format is not known by default. Just assume
    // for now.
    auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
    Offset<Vector<uint8_t>> custom_options = 0;

    if (!registration.custom_name) {
      // builtin
      auto builtin_options_and_type = CreateBuiltinUnion(
          fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
          node.builtin_data);
      builtin_options = builtin_options_and_type.second;
      builtin_options_type = builtin_options_and_type.first;
    } else {
      auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
      if (custom_writer != custom_op_to_writer_.end() &&
          custom_writer->second) {
        // delegate to custom writer if it exists
        custom_writer->second(fbb, interpreter_, op_index, &custom_options,
                              &custom_options_format);
      } else {
        // use the custom data as fact
        custom_options = fbb->CreateVector(
            reinterpret_cast<const uint8_t*>(node.custom_initial_data),
            node.custom_initial_data_size);
      }
    }

    int opcode_index = operator_to_opcode[op_index];
    // Tensor indices must be remapped because temporary/unused tensors are
    // not exported (see ExportTensors).
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
    auto inputs = ExportVector<int32_t>(fbb, written_inputs);
    auto outputs = ExportVector<int32_t>(fbb, written_outputs);
    operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
                                       builtin_options_type, builtin_options,
                                       custom_options, custom_options_format));
  }

  return fbb->template CreateVector<Offset<Operator>>(operators);
}
+
// Serializes all exportable tensors into a flatbuffer Tensor vector.
// Temporary tensors and tensors listed in unused_tensors_ are dropped, and
// each exported tensor gets a compacted index recorded in
// tensor_to_written_tensor_ — a side effect the remapping callers rely on.
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
    FlatBufferBuilder* fbb) {
  // Initialized to -1.
  // A value of -1 means this tensor will not be exported.
  tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);

  std::vector<Offset<Tensor>> tensors;

  // Make a map from tensor index to whether the tensor is a temporary.
  std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
  for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
    const auto* node_and_registration =
        interpreter_->node_and_registration(op_index);
    for (auto tensor_index :
         TfLiteIntArrayView(node_and_registration->first.temporaries))
      tensor_is_temporary[tensor_index] = true;
  }

  // Now we need to remap all used tensor indices
  int curr_output_index = 0;
  for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
       tensor_index++) {
    // Temporary tensors and unused tensors will not be written.
    if (!tensor_is_temporary[tensor_index] &&
        unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
      tensor_to_written_tensor_[tensor_index] = curr_output_index++;
    }
  }

  for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
       ++tensor_index) {
    // Tensor not exported.
    if (tensor_to_written_tensor_[tensor_index] == -1) continue;

    if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
      // We only need to convert non temporaries
      // NOTE(review): a tensor skipped by this `continue` was already
      // assigned a written index above, which would desynchronize the
      // emitted tensor vector from the remapped indices; confirm no other
      // allocation type can reach this point.
      if (tensor->allocation_type != kTfLiteArenaRw &&
          tensor->allocation_type != kTfLiteMmapRo &&
          tensor->allocation_type != kTfLiteArenaRwPersistent)
        continue;
      // Allocate a buffer index
      int buffer_index = 0;  // This is null
      if (tensor->allocation_type == kTfLiteMmapRo) {
        // Read-only (weight) data goes out-of-line into the buffers table.
        buffer_index = buffers_.size();
        buffers_.push_back(std::make_pair(
            reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
      }
      // Primitive type.
      TensorType type = TfLiteTypeToSchemaType(tensor->type);
      // Handle quantization
      const Offset<Vector<float>> null_array;
      Offset<Vector<float>> scale_array;
      Offset<Vector<int64_t>> zero_point_array;
      if (tensor->params.scale != 0.f) {
        // We have quantization, make a single argument array (multi channel
        // quant needs updating here).
        scale_array = fbb->CreateVector<float>({tensor->params.scale});
        zero_point_array =
            fbb->CreateVector<int64_t>({tensor->params.zero_point});
      }
      Offset<QuantizationParameters> quantization_params =
          CreateQuantizationParameters(*fbb, null_array, null_array,
                                       scale_array, zero_point_array);
      // Shape
      TfLiteIntArrayView shape_view(tensor->dims);
      std::vector<int> shape =
          std::vector<int>(shape_view.begin(), shape_view.end());

      tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
                                     type, buffer_index,
                                     fbb->CreateString(tensor->name),
                                     quantization_params, tensor->is_variable));
    }
  }
  return fbb->template CreateVector<Offset<Tensor>>(tensors);
}
+
+Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
+    FlatBufferBuilder* fbb) {
+  std::vector<Offset<Buffer>> buffer_vector;
+  for (auto buffer : buffers_) {
+    auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+    buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+  }
+  return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
+}
+
+Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
+    FlatBufferBuilder* fbb) {
+  std::vector<Offset<OperatorCode>> codes;
+  for (auto it : opcodes_) {
+    const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+    codes.push_back(CreateOperatorCodeDirect(
+        *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+  }
+  return fbb->template CreateVector<Offset<OperatorCode>>(codes);
+}
+
+template <class T>
+std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
+    const T& input) {
+  std::vector<int> output;
+  output.reserve(input.size());
+  for (int x : input) {
+    if (tensor_to_written_tensor_[x] != -1) {
+      output.push_back(tensor_to_written_tensor_[x]);
+    }
+  }
+  return output;
+}
+
// Serializes the interpreter into a finished flatbuffer model and copies
// the bytes into a freshly allocated buffer returned via `out` (length in
// `size`).  Returns kTfLiteError if either out-parameter is null.
TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
                                          size_t* size) {
  if (!out || !size) return kTfLiteError;
  FlatBufferBuilder builder(/*initial_size=*/10240);

  std::vector<Offset<SubGraph>> subgraphs_as_vector;
  {  // subgraph specific stuff
    // ExportTensors() must run before the remapping below: it populates
    // tensor_to_written_tensor_, which RemapTensorIndicesToWritten() reads.
    auto tensors = ExportTensors(&builder);
    std::vector<int> written_inputs =
        RemapTensorIndicesToWritten(interpreter_->inputs());
    std::vector<int> written_outputs =
        RemapTensorIndicesToWritten(interpreter_->outputs());
    auto inputs = ExportVector<int32_t>(&builder, written_inputs);
    auto outputs = ExportVector<int32_t>(&builder, written_outputs);

    auto ops = ExportOperators(&builder);
    subgraphs_as_vector.push_back(
        CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
  }
  Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);

  auto description = builder.CreateString("Exported from Interpreter.");

  auto op_codes = CreateOpCodeTable(&builder);
  auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
                           builder.CreateVector(subgraphs_as_vector),
                           description, buffers);
  ::tflite::FinishModelBuffer(builder, model);
  // Copy the builder's internal buffer into caller-owned storage.
  const uint8_t* buffer = builder.GetBufferPointer();
  *size = builder.GetSize();
  (*out).reset(new uint8_t[*size]);
  memcpy(out->get(), buffer, *size);
  return kTfLiteOk;
}
+
+TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
+  std::unique_ptr<uint8_t[]> buffer;
+  size_t size;
+  TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+
+  FILE* fp = fopen(filename.c_str(), "wb");
+  if (!fp) return kTfLiteError;
+
+  if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
+  if (fclose(fp)) return kTfLiteError;
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::RegisterCustomWriter(
+    const std::string& custom_name, CustomWriter custom_writer) {
+  if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
+    return kTfLiteError;
+  }
+  custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
+  return kTfLiteOk;
+}
+
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
new file mode 100644
index 0000000..a5f1469
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -0,0 +1,131 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
+//
+// Usage:
+//  From command line:
+//   bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
+//     -- foo.tflite foo.out.tflite
+//
+// From C++
+//   std::unique_ptr<Interpreter> interpreter;
+//   // Build Interpreter however
+//   // ... <omitted>
+//   InterpreterWriter(interpreter.get()).Write("output.tflite");
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#include <iostream>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+// Handles writing TensorFlow Lite running interpreter to a serialized TF lite
+// file format.
+class InterpreterWriter {
+ public:
+  typedef flatbuffers::Offset<Operator> (*CustomWriter)(
+      flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
+      int node_index,
+      flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
+      CustomOptionsFormat* custom_options_format);
+
+  // Construct an interpreter writer for the specified `interpreter`. Then,
+  // use .Write() or .GetBuffer(...) to extract the data.
+  explicit InterpreterWriter(Interpreter* interpreter)
+      : interpreter_(interpreter) {
+    buffers_.push_back(std::make_pair(nullptr, 0));
+  }
+
+  // Get a buffer and size of a serialized flatbuffer.
+  TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+  // Write the serialized flatbuffer to the prescribed `filename`.
+  TfLiteStatus Write(const std::string& filename);
+  // Registers a custom writer for a custom op. The customization allows the
+  // caller to change the custom data.
+  TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
+                                    CustomWriter custom_writer);
+  // Declares the tensors that are unused and shouldn't be written.
+  void SetUnusedTensors(const std::set<int>& unused_tensors) {
+    unused_tensors_ = unused_tensors;
+  }
+
+ private:
+  template <class T>
+  using Offset = flatbuffers::Offset<T>;
+  template <class T_OUTPUT, class T_INPUT>
+  Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
+      flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
+  Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
+      flatbuffers::FlatBufferBuilder* fbb);
+  Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
+      flatbuffers::FlatBufferBuilder* fbb);
+  Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+      flatbuffers::FlatBufferBuilder* fbb);
+  Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+      flatbuffers::FlatBufferBuilder* fbb);
+
+  template <class T>
+  std::vector<int> RemapTensorIndicesToWritten(const T& input);
+
+  int GetOpCodeForBuiltin(int builtin_op_index) {
+    // Insert a new opcode entry if this builtin op has not been seen yet.
+    std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
+        builtin_op_to_opcode_.insert(
+            std::make_pair(builtin_op_index, opcodes_.size()));
+    if (result.second) {
+      opcodes_.push_back({builtin_op_index, ""});
+    }
+    return result.first->second;
+  }
+
+  int GetOpCodeForCustom(const std::string& custom_name) {
+    std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
+        custom_op_to_opcode_.insert(
+            std::make_pair(custom_name, opcodes_.size()));
+    if (result.second) {
+      opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+    }
+    return result.first->second;
+  }
+
+  // The interpreter we are writing
+  Interpreter* interpreter_;
+  // Keep track of byte buffers
+  std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+  // List of op codes and mappings from builtin or custom op to opcode
+  struct OpCode {
+    int builtin;
+    std::string custom;
+  };
+  std::set<int> unused_tensors_;
+  // For every tensor index in the interpreter, the corresponding index in the
+  // written file. These differ because temporary and unused tensors are skipped.
+  std::vector<int> tensor_to_written_tensor_;
+  // List of used opcodes
+  std::vector<OpCode> opcodes_;
+  std::unordered_map<int, int> builtin_op_to_opcode_;
+  std::unordered_map<std::string, int> custom_op_to_opcode_;
+  std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
new file mode 100644
index 0000000..49194a7
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+// Build a small interpreter (3 tensors, one ADD node), write it out and reload.
+// TODO(b/113731921): add more tests.
+TEST(Writer, BasicTest) {
+  Interpreter interpreter;
+  interpreter.AddTensors(3);
+  float foo[] = {1, 2, 3};
+  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+                                           TfLiteQuantizationParams());
+  interpreter.SetTensorParametersReadOnly(
+      1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
+      reinterpret_cast<char*>(foo), sizeof(foo));
+  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+                                           TfLiteQuantizationParams());
+  interpreter.SetInputs({0, 1});
+  interpreter.SetOutputs({2});
+  const char* initial_data = "";
+  tflite::ops::builtin::BuiltinOpResolver resolver;
+  TfLiteAddParams* builtin_data =
+      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+  builtin_data->activation = kTfLiteActNone;
+  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+                                    reinterpret_cast<void*>(builtin_data), reg);
+
+  InterpreterWriter writer(&interpreter);
+  writer.Write("/tmp/test.tflite");
+  std::unique_ptr<FlatBufferModel> model =
+      FlatBufferModel::BuildFromFile("/tmp/test.tflite");
+  InterpreterBuilder builder(*model, resolver);
+  std::unique_ptr<Interpreter> new_interpreter;
+  builder(&new_interpreter);
+}
+
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md
deleted file mode 100644
index e3db478..0000000
--- a/tensorflow/contrib/lite/g3doc/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-This is a *work-in-progress* TF Lite subsite for:
-https://www.tensorflow.org/mobile
-
-DO NOT PUBLISH
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index 9119e49..b3f21e2 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -5,7 +5,8 @@
   rows:
   - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
     items:
-    - description: >
+    - classname: devsite-landing-row-50
+      description: >
         TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
         embedded devices. It enables on-device machine learning inference with
         low latency and a small binary size. TensorFlow Lite also supports
@@ -33,7 +34,7 @@
           icon_name: chevron_right
           foreground: theme
           background: grey
-    - code_block: |
+      code_block: |
         <pre class = "prettyprint">
         $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
                --input_format=TENSORFLOW_GRAPHDEF \
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
deleted file mode 100644
index 70031a3..0000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
+++ /dev/null
@@ -1,10 +0,0 @@
-Project: /mobile/_project.yaml
-Book: /mobile/_book.yaml
-page_type: reference
-<style> table img { max-width: 100%; } </style>
-<script src="/_static/js/managed/mathjax/MathJax.js?config=TeX-AMS-MML_SVG"></script>
-
-<!-- DO NOT EDIT! Automatically generated file. -->
-# All symbols in TensorFlow Lite
-
-TEMP PAGE
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index f255017..69616c7 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -37,7 +37,7 @@
 ```
 ### Data Alignment
 
-TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended
+TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended
 that all data provided to TensorFlow Lite be aligned that way.
 
 ### Error Reporting
@@ -112,7 +112,7 @@
 
   * Tensors are represented by integers, in order to avoid string comparisons
     (and any fixed dependency on string libraries).
-  * An interpreter must not be accessed from concurrent threads
+  * An interpreter must not be accessed from concurrent threads.
   * Memory allocation for input and output tensors must be triggered
     by calling AllocateTensors() right after resizing tensors.
 
@@ -169,7 +169,7 @@
 including all the tensors. The latter allows implementations to access their
 inputs and outputs.
 
-When the interpreter loads a model, it calls init() once for each node in the
+When the interpreter loads a model, it calls `init()` once for each node in the
 graph. A given `init()` will be called more than once if the op is used
 multiple times in the graph. For custom ops a configuration buffer will be
 provided, containing a flexbuffer that maps parameter names to their values.
@@ -210,8 +210,9 @@
 
 Note that registration is not automatic and an explicit call to
 `Register_MY_CUSTOM_OP` should be made somewhere. While the standard
-`:builtin_ops` takes care of the registration of builtins, custom ops will have
-to be collected in separated custom libraries.
+`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the
+registration of builtins, custom ops will have to be collected in separate
+custom libraries.
 
 ### Customizing the kernel library
 
@@ -232,7 +233,7 @@
 };
 ```
 
-The regular usage will require the developer to use the `BuiltinOpResolver` and
+Regular usage will require the developer to use the `BuiltinOpResolver` and
 write:
 
 ```c++
@@ -308,18 +309,25 @@
 
 #### Inputs
 
-Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of
-the supported primitive types.
+Each input should be an array or multi-dimensional array of the supported
+primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is
+an array or multi-dimensional array, the associated input tensor will be
+implicitly resized to the array's dimensions at inference time. If the input is
+a `ByteBuffer`, the caller should first manually resize the associated input
+tensor (via `Interpreter.resizeInput()`) before running inference.
 
-The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid
-unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its
-order must be `ByteOrder.nativeOrder()`. After it is used for a model inference,
-it must remain unchanged until the model inference is finished.
+When using `ByteBuffer`, prefer using direct byte buffers, as this allows the
+`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte
+buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a
+model inference, it must remain unchanged until the model inference is finished.
 
 #### Outputs
 
-Each output should be an array, or a multi-dimensional array of the supported
-primitive types.
+Each output should be an array or multi-dimensional array of the supported
+primitive types, or a `ByteBuffer` of the appropriate size. Note that some models
+have dynamic outputs, where the shape of output tensors can vary depending on
+the input. There's no straightforward way of handling this with the existing
+Java inference API, but planned extensions will make this possible.
 
 #### Running Model Inference
 
@@ -339,9 +347,10 @@
 where each entry in `inputs` corresponds to an input tensor and
 `map_of_indices_to_outputs` maps indices of output tensors to the
 corresponding output data. In both cases the tensor indices should correspond to
-the values given to the `TensorFlow Lite Optimized Converter` when the model was
-created. Be aware that the order of tensors in `input` must match the order
-given to the `TensorFlow Lite Optimized Converter`.
+the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md)
+when the model was created. Be aware that the order of tensors in `input` must
+match the order given to the `TensorFlow Lite Optimized Converter`.
+
 
 The Java API also provides convenient functions for app developers to get the
 index of any model input or output using a tensor name:
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 0f9d016..a4267ee 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -3,65 +3,68 @@
 
 ## Image classification (Float Models)
 
-Model Name          | Paper_Model_Files^                                                                                                                                                                        | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
-------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
-DenseNet            | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz)            | 43.6 Mb    | 64.2%          | 85.6%          | 894 ms                | 1262 ms
-SqueezeNet          | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz)          | 5.0 Mb     | 49.0%          | 72.9%          | 224 ms                | 255 ms
-NASNet mobile       | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz)       | 21.4 Mb    | 74.2%          | 91.7%          | 261 ms                | 389 ms
-NASNet large        | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz)        | 355.3 Mb   | 82.8%          | 96.2%          | 6697 ms               | 7940 ms
-ResNet_V2_50        | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz)        | 102.3 Mb   | 68.1%          | 88.4%          | 942 ms                | 1008 ms
-ResNet_V2_101       | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz)       | 178.3 Mb   | 70.4%          | 89.6%          | 1880 ms               | 1970 ms
-Inception_V3        | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz)         | 95.3 Mb    | 78.2%          | 94.0%          | 1433 ms               | 1522 ms
-Inception_V4        | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz)         | 170.7 Mb   | 80.4%          | 95.2%          | 2986 ms               | 3139 ms
-Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb   | 77.8%          | 94.1%          | 2731 ms               | 2926 ms
-Mobilenet_V1_0.25_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)                                       | 1.9 Mb     | 41.6%          | 66.6%          | 6.2 ms                | 13.0 ms
-Mobilenet_V1_0.25_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz)                                       | 1.9 Mb     | 45.7%          | 70.6%          | 8.6 ms                | 19.5 ms
-Mobilenet_V1_0.25_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz)                                       | 1.9 Mb     | 47.5%          | 72.4%          | 12.1 ms               | 27.8 ms
-Mobilenet_V1_0.25_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz)                                       | 1.9 Mb     | 50.0%          | 74.4%          | 16.2 ms               | 37.3 ms
-Mobilenet_V1_0.50_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz)                                        | 5.3 Mb     | 56.5%          | 79.5%          | 18.1 ms               | 29.9 ms
-Mobilenet_V1_0.50_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)                                        | 5.3 Mb     | 59.3%          | 82.1%          | 26.8 ms               | 45.9 ms
-Mobilenet_V1_0.50_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz)                                        | 5.3 Mb     | 62.0%          | 83.7%          | 35.6 ms               | 65.3 ms
-Mobilenet_V1_0.50_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz)                                        | 5.3 Mb     | 63.5%          | 85.0%          | 47.6 ms               | 164.2 ms
-Mobilenet_V1_0.75_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz)                                       | 10.3 Mb    | 62.3%          | 84.1%          | 34.6 ms               | 48.7 ms
-Mobilenet_V1_0.75_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz)                                       | 10.3 Mb    | 65.5%          | 86.1%          | 51.3 ms               | 75.2 ms
-Mobilenet_V1_0.75_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz)                                       | 10.3 Mb    | 67.4%          | 87.4%          | 71.7 ms               | 107.0 ms
-Mobilenet_V1_0.75_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz)                                       | 10.3 Mb    | 68.6%          | 88.3%          | 95.7 ms               | 143.4 ms
-Mobilenet_V1_1.0_128   | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz)                                        | 16.9 Mb    | 65.5%          | 85.9%          | 57.4 ms               | 76.8 ms
-Mobilenet_V1_1.0_160   | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz)                                        | 16.9 Mb    | 68.3%          | 87.8%          | 86.0 ms               | 117.7 ms
-Mobilenet_V1_1.0_192   | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz)                                        | 16.9 Mb    | 70.2%          | 89.3%          | 118.6 ms              | 167.3 ms
-Mobilenet_V1_1.0_224   | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)                                        | 16.9 Mb    | 71.3%          | 90.1%          | 160.1 ms              | 224.3 ms
+Model Name            | Paper_Model_Files^                                                                                                                                                                        | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
+--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
+DenseNet              | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz)            | 43.6 Mb    | 64.2%          | 85.6%          | 894 ms                | 1262 ms
+SqueezeNet            | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz)          | 5.0 Mb     | 49.0%          | 72.9%          | 224 ms                | 255 ms
+NASNet mobile         | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz)       | 21.4 Mb    | 73.9%          | 91.5%          | 261 ms                | 389 ms
+NASNet large          | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz)        | 355.3 Mb   | 82.6%          | 96.1%          | 6697 ms               | 7940 ms
+ResNet_V2_101         | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz)                                   | 178.3 Mb   | 76.8%          | 93.6%          | 1880 ms               | 1970 ms
+Inception_V3          | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz)         | 95.3 Mb    | 77.9%          | 93.8%          | 1433 ms               | 1522 ms
+Inception_V4          | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz)         | 170.7 Mb   | 80.1%          | 95.1%          | 2986 ms               | 3139 ms
+Inception_ResNet_V2   | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb   | 77.5%          | 94.0%          | 2731 ms               | 2926 ms
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)                                       | 1.9 Mb     | 41.4%          | 66.2%          | 6.2 ms                | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz)                                       | 1.9 Mb     | 45.4%          | 70.2%          | 8.6 ms                | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz)                                       | 1.9 Mb     | 47.1%          | 72.0%          | 12.1 ms               | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz)                                       | 1.9 Mb     | 49.7%          | 74.1%          | 16.2 ms               | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz)                                        | 5.3 Mb     | 56.2%          | 79.3%          | 18.1 ms               | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)                                        | 5.3 Mb     | 59.0%          | 81.8%          | 26.8 ms               | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz)                                        | 5.3 Mb     | 61.7%          | 83.5%          | 35.6 ms               | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz)                                        | 5.3 Mb     | 63.2%          | 84.9%          | 47.6 ms               | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz)                                       | 10.3 Mb    | 62.0%          | 83.8%          | 34.6 ms               | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz)                                       | 10.3 Mb    | 65.2%          | 85.9%          | 51.3 ms               | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz)                                       | 10.3 Mb    | 67.1%          | 87.2%          | 71.7 ms               | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz)                                       | 10.3 Mb    | 68.3%          | 88.1%          | 95.7 ms               | 143.4 ms
+Mobilenet_V1_1.0_128  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz)                                        | 16.9 Mb    | 65.2%          | 85.7%          | 57.4 ms               | 76.8 ms
+Mobilenet_V1_1.0_160  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz)                                        | 16.9 Mb    | 68.0%          | 87.7%          | 86.0 ms               | 117.7 ms
+Mobilenet_V1_1.0_192  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz)                                        | 16.9 Mb    | 69.9%          | 89.1%          | 118.6 ms              | 167.3 ms
+Mobilenet_V1_1.0_224  | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)                                        | 16.9 Mb    | 71.0%          | 89.9%          | 160.1 ms              | 224.3 ms
+Mobilenet_V2_1.0_224  | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz)                                                | 14.0 Mb    | 71.8%          | 90.6%          | 117 ms                |
 
 ^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph.
 
 ^^ The performance numbers are generated in the benchmark on Pixel-2 using
 single thread large core.
 
-^^ Accuracy numbers were computed using the [TFLite accuracy tool](../tools/accuracy/ilsvrc)
-after excluding blacklisted images.
+^^ Accuracy numbers were computed using the
+[TFLite accuracy tool](../tools/accuracy/ilsvrc).
 
 ## Image classification (Quantized Models)
 
-Model Name               | Paper_Model_Files                                                                                                                                         | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb     | 39.8%          | 64.8%          | 3.7 ms
-Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb     | 43.0%          | 68.4%          | 5.5 ms
-Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb     | 46.0%          | 71.2%          | 7.9 ms
-Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb     | 48.5%          | 73.1%          | 10.4 ms
-Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz)  | 1.4 Mb     | 55.2%          | 78.4%          | 8.8 ms
-Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz)  | 1.4 Mb     | 57.5%          | 80.7%          | 13.0 ms
-Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz)  | 1.4 Mb     | 60.2%          | 82.3%          | 18.3 ms
-Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz)  | 1.4 Mb     | 61.5%          | 83.5%          | 24.7 ms
-Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb     | 56.2%          | 79.4%          | 16.2 ms
-Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb     | 62.7%          | 83.9%          | 24.3 ms
-Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb     | 66.4%          | 86.4%          | 33.8 ms
-Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb     | 67.2%          | 87.0%          | 45.4 ms
-Mobilenet_V1_1.0_128_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz)  | 4.3 Mb     | 63.6%          | 84.3%          | 24.9 ms
-Mobilenet_V1_1.0_160_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz)  | 4.3 Mb     | 67.2%          | 86.9%          | 37.4 ms
-Mobilenet_V1_1.0_192_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz)  | 4.3 Mb     | 69.4%          | 88.3%          | 51.9 ms
-Mobilenet_V1_1.0_224_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)  | 4.3 Mb     | 70.2%          | 89.1%          | 70.2 ms
+Model Name                  | Paper_Model_Files                                                                                                                                         | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
+--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb     | 39.5%          | 64.4%          | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb     | 42.8%          | 68.1%          | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb     | 45.7%          | 70.8%          | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb     | 48.2%          | 72.8%          | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz)  | 1.4 Mb     | 54.9%          | 78.1%          | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz)  | 1.4 Mb     | 57.2%          | 80.5%          | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz)  | 1.4 Mb     | 59.9%          | 82.1%          | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz)  | 1.4 Mb     | 61.2%          | 83.2%          | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb     | 55.9%          | 79.1%          | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb     | 62.4%          | 83.7%          | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb     | 66.1%          | 86.2%          | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb     | 66.9%          | 86.9%          | 45.4 ms
+Mobilenet_V1_1.0_128_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz)  | 4.3 Mb     | 63.3%          | 84.1%          | 24.9 ms
+Mobilenet_V1_1.0_160_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz)  | 4.3 Mb     | 66.9%          | 86.7%          | 37.4 ms
+Mobilenet_V1_1.0_192_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz)  | 4.3 Mb     | 69.1%          | 88.1%          | 51.9 ms
+Mobilenet_V1_1.0_224_quant  | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)  | 4.3 Mb     | 70.0%          | 89.0%          | 70.2 ms
+Mobilenet_v2_1.0_224_quant  | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz)              | 3.4 Mb     | 70.8%          | 89.9%          | 80.3 ms
+Inception_v3_quant          | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz)                      | 23 Mb      | 77.5%          | 93.7%          | 637 ms
 
 ## Other models
 
 Model                   | TF Lite FlatBuffer
 ----------------------- | :----------------:
-Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
+Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html),
+[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 77268d7..8ee8382 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -17,7 +17,7 @@
 
 #include <vector>
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5ab53f4..2657bcd 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -21,9 +21,9 @@
 #include <cstring>
 
 #include "tensorflow/contrib/lite/arena_planner.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/graph_info.h"
 #include "tensorflow/contrib/lite/memory_planner.h"
 #include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -123,6 +123,7 @@
   context_.AddTensors = AddTensors;
   context_.tensors = nullptr;
   context_.tensors_size = 0;
+  context_.allow_fp32_relax_to_fp16 = false;
   context_.recommended_num_threads = -1;
   context_.GetExternalContext = GetExternalContext;
   context_.SetExternalContext = SetExternalContext;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 2b1f181..aa2bc4d 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,10 +23,11 @@
 #include <vector>
 
 #include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/memory_planner.h"
 #include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
 
 namespace tflite {
 
@@ -335,6 +336,19 @@
   // Set the number of threads available to the interpreter.
   void SetNumThreads(int num_threads);
 
+  // Allow float16 precision for FP32 calculation when possible.
+  // Default: not allowed.
+  // WARNING: This is an experimental API and subject to change.
+  void SetAllowFp16PrecisionForFp32(bool allow) {
+    context_.allow_fp32_relax_to_fp16 = allow;
+  }
+
+  // Get the half precision flag.
+  // WARNING: This is an experimental API and subject to change.
+  bool GetAllowFp16PrecisionForFp32() const {
+    return context_.allow_fp32_relax_to_fp16;
+  }
+
   // Allow a delegate to look at the graph and modify the graph to handle
   // parts of the graph themselves. After this is called, the graph may
   // contain new nodes that replace 1 more nodes.
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 5bcf092..cdede43 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -15,7 +15,7 @@
 
 #include "tensorflow/contrib/lite/interpreter.h"
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index e3cea19..6a3f065 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -20,9 +20,6 @@
       - Make sure to install the latest version of Bazel. Some distributions
         ship with Bazel 0.5.4, which is too old.
       - Bazel requires Android Build Tools `26.0.1` or higher.
-      - **Bazel is incompatible with NDK revisions 15 and above,** with revision
-        16 being a compile-breaking change. [Download an older version manually
-        instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites)
       - You also need to install the Android Support Repository, available
         through Android Studio under `Android SDK Manager -> SDK Tools ->
         Android Support Repository`.
@@ -37,8 +34,7 @@
       - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
         you have installed.
       - By default, Android Studio will install the SDK to `~/Android/Sdk` and
-        the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual
-        download until Bazel supports NDK 16. See bullet points under (1)).
+        the NDK to `~/Android/Sdk/ndk-bundle`.
 
 2. Build the app with Bazel. The demo needs C++11:
 
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 06f46fb..781289c 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -35,6 +35,7 @@
         "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
     ],
     main_class = "org.tensorflow.ovic.OvicValidator",
+    tags = ["no_oss"],
     deps = [
         "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java",
     ],
@@ -47,6 +48,7 @@
         "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
     ],
     manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+    tags = ["no_oss"],
     deps = [
         "//tensorflow/contrib/lite/java:tensorflowlite",
         "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
@@ -61,6 +63,7 @@
         "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
     ],
     javacopts = JAVACOPTS,
+    tags = ["no_oss"],
     deps = [
         "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
         "//tensorflow/contrib/lite/java:tensorflowlite_java",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 55ca47f..06b35d7 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -20,7 +20,7 @@
 #include <stdio.h>
 #include <time.h>
 #include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
 #include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
@@ -124,9 +124,9 @@
  */
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
-                                                           jclass clazz,
-                                                           jlong handle,
-                                                           jint num_threads);
+                                                             jclass clazz,
+                                                             jlong handle,
+                                                             jint num_threads);
 /*
  *  Class:     org_tensorflow_lite_NativeInterpreterWrapper
  *  Method:
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index c020f13..2f73128 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -17,7 +17,7 @@
 #define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
 
 #include <jni.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 8287115..40f28ae 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -6,7 +6,7 @@
 
 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android")
 
 # Suppress warnings that are introduced by Eigen Tensor.
 EXTRA_EIGEN_COPTS = select({
@@ -66,7 +66,7 @@
     deps = [
         ":op_macros",
         "//tensorflow/contrib/lite:arena_planner",
-        "//tensorflow/contrib/lite:context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels/internal:optimized",
     ],
 )
@@ -82,7 +82,7 @@
     copts = tflite_copts(),
     deps = [
         ":op_macros",
-        "//tensorflow/contrib/lite:context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "@gemmlowp",
     ],
 )
@@ -93,7 +93,7 @@
         "activation_functor.h",
     ],
     deps = [
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ],
 )
 
@@ -113,9 +113,9 @@
         "kernel_util.h",
     ],
     deps = [
-        "//tensorflow/contrib/lite:builtin_op_data",
-        "//tensorflow/contrib/lite:context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels/internal:round",
+        "//tensorflow/contrib/lite/kernels/internal:types",
     ],
 )
 
@@ -147,7 +147,16 @@
 )
 
 cc_library(
-    name = "builtin_ops",
+    name = "padding",
+    srcs = [],
+    hdrs = ["padding.h"],
+    deps = [
+        "//tensorflow/contrib/lite/c:c_api_internal",
+    ],
+)
+
+cc_library(
+    name = "builtin_op_kernels",
     srcs = [
         "activations.cc",
         "add.cc",
@@ -177,6 +186,7 @@
         "gather.cc",
         "hashtable_lookup.cc",
         "l2norm.cc",
+        "layer_norm_lstm.cc",
         "local_response_norm.cc",
         "logical.cc",
         "lsh_projection.cc",
@@ -191,7 +201,7 @@
         "pooling.cc",
         "pow.cc",
         "reduce.cc",
-        "register.cc",
+        "relu1.cc",
         "reshape.cc",
         "resize_bilinear.cc",
         "select.cc",
@@ -215,33 +225,45 @@
         "unpack.cc",
     ],
     hdrs = [
-        "padding.h",
-        "register.h",
     ],
-    copts = tflite_copts() + EXTRA_EIGEN_COPTS,
+    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
+    visibility = ["//visibility:private"],
     deps = [
         ":activation_functor",
         ":eigen_support",
         ":kernel_util",
         ":op_macros",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        ":padding",
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:string_util",
         "//tensorflow/contrib/lite:util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:gemm_support",
         "//tensorflow/contrib/lite/kernels/internal:audio_utils",
         "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
         "//tensorflow/contrib/lite/kernels/internal:optimized",
         "//tensorflow/contrib/lite/kernels/internal:optimized_base",
         "//tensorflow/contrib/lite/kernels/internal:quantization_util",
-        "//tensorflow/contrib/lite/kernels/internal:reference",
         "//tensorflow/contrib/lite/kernels/internal:reference_base",
+        "//tensorflow/contrib/lite/kernels/internal:tensor",
         "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
         "@farmhash_archive//:farmhash",
         "@flatbuffers",
     ],
 )
 
+cc_library(
+    name = "builtin_ops",
+    srcs = ["register.cc"],
+    hdrs = ["register.h"],
+    deps = [
+        ":builtin_op_kernels",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite:util",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+    ],
+)
+
 tf_cc_test(
     name = "audio_spectrogram_test",
     size = "small",
@@ -294,6 +316,23 @@
 )
 
 tf_cc_test(
+    name = "relu1_test",
+    size = "small",
+    srcs = ["relu1_test.cc"],
+    tags = [
+        "no_oss",
+        "tflite_not_portable_ios",
+    ],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
+tf_cc_test(
     name = "activations_test",
     size = "small",
     srcs = ["activations_test.cc"],
@@ -728,8 +767,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -745,8 +784,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -904,6 +943,20 @@
 )
 
 tf_cc_test(
+    name = "layer_norm_lstm_test",
+    size = "small",
+    srcs = ["layer_norm_lstm_test.cc"],
+    tags = ["tflite_not_portable_ios"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
+
+tf_cc_test(
     name = "lstm_test",
     size = "small",
     srcs = ["lstm_test.cc"],
@@ -1001,8 +1054,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1104,8 +1157,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1121,8 +1174,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1138,8 +1191,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1155,8 +1208,8 @@
     ],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1169,8 +1222,8 @@
     tags = ["tflite_not_portable_ios"],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
@@ -1196,8 +1249,8 @@
     tags = ["tflite_not_portable_ios"],
     deps = [
         ":builtin_ops",
-        "//tensorflow/contrib/lite:builtin_op_data",
         "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest",
     ],
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cc..e075dc7 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@
 #include <cmath>
 #include <cstdlib>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 9c891fe..b2d9b84 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -200,7 +200,7 @@
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
 
   const int num_dims = NumDimensions(input);
-  TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4);
+  TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
 
   if (input->type == kTfLiteUInt8) {
     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -453,6 +453,19 @@
   Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
 }
 
+// Takes a 3D tensor and performs softmax along the last dimension.
+void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+                    TfLiteSoftmaxParams* params) {
+  const int batch_size = input->dims->data[0];
+  const int intermediate_size = input->dims->data[1];
+  const int input_size = input->dims->data[2];
+  optimized_ops::Softmax(
+      GetTensorData<float>(input),
+      GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+      params->beta, GetTensorData<float>(output),
+      GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
 void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
                         TfLiteSoftmaxParams* params, OpData* data) {
   // TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -480,6 +493,19 @@
                          GetTensorShape({batch_size, 1, 1, input_size}));
 }
 
+void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+                        TfLiteSoftmaxParams* params, OpData* data) {
+  const int batch_size = input->dims->data[0];
+  const int intermediate_size = input->dims->data[1];
+  const int input_size = input->dims->data[2];
+  optimized_ops::Softmax(
+      GetTensorData<uint8_t>(input),
+      GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+      data->input_multiplier, data->input_left_shift, data->diff_min,
+      GetTensorData<uint8_t>(output),
+      GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
 // Takes a 4D tensor and perform softmax along the forth dimension.
 void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
                     TfLiteSoftmaxParams* params) {
@@ -515,6 +541,10 @@
         Softmax2DFloat(input, output, params);
         return kTfLiteOk;
       }
+      if (NumDimensions(input) == 3) {
+        Softmax3DFloat(input, output, params);
+        return kTfLiteOk;
+      }
       if (NumDimensions(input) == 4) {
         Softmax4DFloat(input, output, params);
         return kTfLiteOk;
@@ -533,6 +563,10 @@
         Softmax2DQuantized(input, output, params, data);
         return kTfLiteOk;
       }
+      if (NumDimensions(input) == 3) {
+        Softmax3DQuantized(input, output, params, data);
+        return kTfLiteOk;
+      }
       if (NumDimensions(input) == 4) {
         Softmax4DQuantized(input, output, params, data);
         return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index e577e3a..9fa47e1 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,76 @@
                   kQuantizedTolerance)));
 }
 
+TEST(FloatActivationsOpTest, Softmax3D) {
+  FloatActivationsOpModel m(0.1,
+                            /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+  m.SetInput({
+      0, -6, 2, 4,   // depth = 0
+      3, -2, 10, 1,  // depth = 1
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+                                 .23463, .12877, .28658, .35003,  //
+                                 .22528, .13664, .45365, .18443,  //
+                             })));
+
+  // Same input, but a different shape.
+  FloatActivationsOpModel m2(0.1,
+                             /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
+  m2.SetInput({
+      0, -6,  //
+      2, 4,   //
+      3, -2,  //
+      10, 1,  //
+  });
+  m2.Invoke();
+  EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+                                  0.645656, 0.354344,  //
+                                  0.450166, 0.549834,  //
+                                  0.622459, 0.377541,  //
+                                  0.710949, 0.28905,   //
+                              })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+  QuantizedActivationsOpModel m(
+      0.1,
+      /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+  m.SetInput<uint8_t>({
+      0, -6, 2, 4,   // depth = 0
+      3, -2, 10, 1,  // depth = 1
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      .23463, .12877, .28658, .35003,  //
+                      .22528, .13664, .45365, .18443,  //
+                  },
+                  kQuantizedTolerance)));
+
+  // Same input, but a different shape.
+  QuantizedActivationsOpModel m2(
+      0.1,
+      /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+  m2.SetInput<uint8_t>({
+      0, -6,  //
+      2, 4,   //
+      3, -2,  //
+      10, 1,  //
+  });
+  m2.Invoke();
+  EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {
+                      0.645656, 0.354344,  //
+                      0.450166, 0.549834,  //
+                      0.622459, 0.377541,  //
+                      0.710949, 0.28905,   //
+                  },
+                  kQuantizedTolerance)));
+}
+
 TEST(FloatActivationsOpTest, Softmax1D) {
   FloatActivationsOpModel m(0.1,
                             /*input=*/{TensorType_FLOAT32, {8}});
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index af9b5c7..b4393e8 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 6e05f5a..b91e348 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 1170d84..44ef587 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index c5a5c01..1aa2760 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@
 #include <stddef.h>
 #include <stdint.h>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 4efa9d5..fe2865d 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index af47b33..541f320 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
@@ -104,13 +104,44 @@
 // Cell state tensors of size {n_batch, n_cell}
 constexpr int kBwInputCellStateTensor = 38;
 
+// Auxiliary input and weights when stacking.
+constexpr int kAuxInputTensor = 39;  // Optional
+// Forward weights.
+constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
+constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
+constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
+constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
+// Backward weights.
+constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
+constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
+constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
+constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional
+
 // Output tensors.
 constexpr int kFwOutputTensor = 0;
 constexpr int kBwOutputTensor = 1;
 
+// Temporary tensors.
+enum TemporaryTensor {
+  // Scratch buffers for input, forget, etc. gates
+  kFwScratchBuffer = 0,
+  kBwScratchBuffer = 1,
+  // Quantized tensors needed for the hybrid kernel.
+  kInputQuantized = 2,
+  kAuxInputQuantized = 3,  // Quantized tensor needed for auxiliary input.
+  kFwActivationStateQuantized = 4,
+  kBwActivationStateQuantized = 5,
+  kFwCellStateQuantized = 6,
+  kBwCellStateQuantized = 7,
+  kScalingFactors = 8,
+  kProductScalingFactors = 9,
+  kRecoveredCellWeights = 10,
+  kNumTemporaryTensors = 11
+};
+
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   auto* scratch_tensor_index = new int;
-  context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+  context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
   return scratch_tensor_index;
 }
 
@@ -131,7 +162,7 @@
     int input_gate_bias_tensor, int forget_gate_bias_tensor,
     int cell_gate_bias_tensor, int output_gate_bias_tensor,
     int projection_weights_tensor, int projection_bias_tensor) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Making sure clipping parameters have valid values.
   // == 0 means no clipping
@@ -318,13 +349,14 @@
   int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
 
   // Check we have all the inputs and outputs we need.
-  TF_LITE_ENSURE_EQ(context, node->inputs->size, 39);
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
   TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
 
   // Inferring batch size, number of outputs and sequence length and
   // number of cells from the input tensors.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input->dims->size > 1);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
   const int max_time = input->dims->data[0];
   const int n_batch = input->dims->data[1];
   const int n_input = input->dims->data[2];
@@ -348,6 +380,48 @@
       context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                           n_fw_cell));
 
+  // Get (optional) auxiliary inputs and weights.
+  const TfLiteTensor* aux_input =
+      GetOptionalInputTensor(context, node, kAuxInputTensor);
+  const TfLiteTensor* fw_aux_input_to_input_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_forget_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_cell_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_output_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_input_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_forget_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_cell_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_output_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+  const bool aux_inputs_all_or_none =
+      ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) &&
+       (fw_aux_input_to_forget_weights != nullptr) &&
+       (fw_aux_input_to_output_weights != nullptr) &&
+       (bw_aux_input_to_cell_weights != nullptr) &&
+       (bw_aux_input_to_forget_weights != nullptr) &&
+       (bw_aux_input_to_output_weights != nullptr)) ||
+      ((fw_aux_input_to_cell_weights == nullptr) &&
+       (fw_aux_input_to_forget_weights == nullptr) &&
+       (fw_aux_input_to_output_weights == nullptr) &&
+       (bw_aux_input_to_cell_weights == nullptr) &&
+       (bw_aux_input_to_forget_weights == nullptr) &&
+       (bw_aux_input_to_output_weights == nullptr));
+  TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+  const bool has_aux_input = (aux_input != nullptr);
+
+  if (has_aux_input) {
+    // Check that aux_input has the same dimensions (except last) as the input.
+    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+  }
+
   // Get the pointer to output, activation_state and cell_state buffer tensors.
   TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
   TfLiteTensor* fw_activation_state =
@@ -370,16 +444,28 @@
   TF_LITE_ENSURE_OK(context,
                     context->ResizeTensor(context, fw_output, fw_output_size));
 
-  // Create a scratch buffer tensor.
+  // The weights are of consistent type, so it suffices to check one.
+  const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
+
   TfLiteIntArrayFree(node->temporaries);
-  node->temporaries = TfLiteIntArrayCreate(2);
-  node->temporaries->data[0] = *scratch_tensor_index;
-  TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+  if (is_hybrid_op) {
+    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+  } else {
+    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
+  }
+  // Create a scratch buffer tensor.
+  node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+  TfLiteTensor* fw_scratch_buffer =
+      GetTemporary(context, node, kFwScratchBuffer);
   fw_scratch_buffer->type = input->type;
   fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
 
   const TfLiteTensor* fw_input_to_input_weights =
       GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
+  if (has_aux_input) {
+    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
+                      fw_input_to_input_weights->dims->data[0]);
+  }
   const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
   TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
   fw_scratch_buffer_size->data[0] = n_batch;
@@ -435,13 +521,19 @@
   TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
 
   // Create a scratch buffer tensor.
-  node->temporaries->data[1] = *(scratch_tensor_index) + 1;
-  TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+  node->temporaries->data[kBwScratchBuffer] =
+      *(scratch_tensor_index) + kBwScratchBuffer;
+  TfLiteTensor* bw_scratch_buffer =
+      GetTemporary(context, node, kBwScratchBuffer);
   bw_scratch_buffer->type = input->type;
   bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
 
   const TfLiteTensor* bw_input_to_input_weights =
       GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
+  if (has_aux_input) {
+    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
+                      bw_input_to_input_weights->dims->data[0]);
+  }
   const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
   TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
   bw_scratch_buffer_size->data[0] = n_batch;
@@ -454,18 +546,528 @@
   }
   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                    bw_scratch_buffer_size));
+  if (is_hybrid_op) {
+    // Allocate temporary tensors to store quantized values of input, aux_input
+    // (if present), activation_state and cell_state tensors.
+    node->temporaries->data[kInputQuantized] =
+        *scratch_tensor_index + kInputQuantized;
+    TfLiteTensor* input_quantized =
+        GetTemporary(context, node, kInputQuantized);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+
+    if (has_aux_input) {
+      node->temporaries->data[kAuxInputQuantized] =
+          *scratch_tensor_index + kAuxInputQuantized;
+      TfLiteTensor* aux_input_quantized =
+          GetTemporary(context, node, kAuxInputQuantized);
+      aux_input_quantized->type = kTfLiteUInt8;
+      aux_input_quantized->allocation_type = kTfLiteArenaRw;
+      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+        TfLiteIntArray* aux_input_quantized_size =
+            TfLiteIntArrayCopy(aux_input->dims);
+        TF_LITE_ENSURE_OK(context,
+                          context->ResizeTensor(context, aux_input_quantized,
+                                                aux_input_quantized_size));
+      }
+    }
+
+    node->temporaries->data[kFwActivationStateQuantized] =
+        *scratch_tensor_index + kFwActivationStateQuantized;
+    TfLiteTensor* fw_activation_state_quantized =
+        GetTemporary(context, node, kFwActivationStateQuantized);
+    fw_activation_state_quantized->type = kTfLiteUInt8;
+    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+                             fw_activation_state->dims)) {
+      TfLiteIntArray* fw_activation_state_quantized_size =
+          TfLiteIntArrayCopy(fw_activation_state->dims);
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, fw_activation_state_quantized,
+                                         fw_activation_state_quantized_size));
+    }
+    node->temporaries->data[kBwActivationStateQuantized] =
+        *scratch_tensor_index + kBwActivationStateQuantized;
+    TfLiteTensor* bw_activation_state_quantized =
+        GetTemporary(context, node, kBwActivationStateQuantized);
+    bw_activation_state_quantized->type = kTfLiteUInt8;
+    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+                             bw_activation_state->dims)) {
+      TfLiteIntArray* bw_activation_state_quantized_size =
+          TfLiteIntArrayCopy(bw_activation_state->dims);
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, bw_activation_state_quantized,
+                                         bw_activation_state_quantized_size));
+    }
+    node->temporaries->data[kFwCellStateQuantized] =
+        *scratch_tensor_index + kFwCellStateQuantized;
+    TfLiteTensor* fw_cell_state_quantized =
+        GetTemporary(context, node, kFwCellStateQuantized);
+    fw_cell_state_quantized->type = kTfLiteUInt8;
+    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+                             fw_cell_state->dims)) {
+      TfLiteIntArray* fw_cell_state_quantized_size =
+          TfLiteIntArrayCopy(fw_cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, fw_cell_state_quantized,
+                                              fw_cell_state_quantized_size));
+    }
+    node->temporaries->data[kBwCellStateQuantized] =
+        *scratch_tensor_index + kBwCellStateQuantized;
+    TfLiteTensor* bw_cell_state_quantized =
+        GetTemporary(context, node, kBwCellStateQuantized);
+    bw_cell_state_quantized->type = kTfLiteUInt8;
+    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+                             bw_cell_state->dims)) {
+      TfLiteIntArray* bw_cell_state_quantized_size =
+          TfLiteIntArrayCopy(bw_cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, bw_cell_state_quantized,
+                                              bw_cell_state_quantized_size));
+    }
+
+    // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is a convenience storage which allows to quantize
+    // a vector once (which produces the scaling factors) and multiply it with
+    // different matrices (which requires multiplying the scaling factors with
+    // the scaling factor of the matrix).
+    node->temporaries->data[kScalingFactors] =
+        *scratch_tensor_index + kScalingFactors;
+    TfLiteTensor* scaling_factors =
+        GetTemporary(context, node, kScalingFactors);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
+    node->temporaries->data[kProductScalingFactors] =
+        *scratch_tensor_index + kProductScalingFactors;
+    TfLiteTensor* prod_scaling_factors =
+        GetTemporary(context, node, kProductScalingFactors);
+    prod_scaling_factors->type = kTfLiteFloat32;
+    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+    prod_scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+                             prod_scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, prod_scaling_factors,
+                                              prod_scaling_factors_size));
+    }
+
+    // Allocate a temporary tensor to store the recovered cell weights. Since
+    // this is used for diagonal matrices, only need to store n_cell values.
+    node->temporaries->data[kRecoveredCellWeights] =
+        *scratch_tensor_index + kRecoveredCellWeights;
+    TfLiteTensor* recovered_cell_weights =
+        GetTemporary(context, node, kRecoveredCellWeights);
+    recovered_cell_weights->type = kTfLiteFloat32;
+    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+    recovered_cell_weights_size->data[0] = n_fw_cell;
+    if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+                             recovered_cell_weights_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, recovered_cell_weights,
+                                              recovered_cell_weights_size));
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  float* aux_input_ptr = nullptr;
+  float* aux_input_to_input_weights_ptr = nullptr;
+  float* aux_input_to_forget_weights_ptr = nullptr;
+  float* aux_input_to_cell_weights_ptr = nullptr;
+  float* aux_input_to_output_weights_ptr = nullptr;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+    aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+    aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+    aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+  }
+
+  // Loop through the sequence.
+  if (forward_sequence) {
+    for (int t = 0; t < max_time; t++) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, aux_input_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          activation_state->data.f, cell_state->data.f, input_gate_scratch,
+          forget_gate_scratch, cell_scratch, output_gate_scratch,
+          output_ptr_time);
+    }
+  } else {
+    // Loop through the sequence backwards.
+    for (int t = max_time - 1; t >= 0; t--) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr,
+          input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+          input_to_output_weights->data.f, aux_input_ptr,
+          aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+          aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+          recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+          recurrent_to_cell_weights->data.f,
+          recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+          cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+          input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+          output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          activation_state->data.f, cell_state->data.f, input_gate_scratch,
+          forget_gate_scratch, cell_scratch, output_gate_scratch,
+          output_ptr_time);
+    }
+  }
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_aux_input_ptr =
+      (aux_input_quantized == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  // Auxiliary input and weights.
+  float* aux_input_ptr = nullptr;
+  int8_t* aux_input_to_input_weights_ptr = nullptr;
+  int8_t* aux_input_to_forget_weights_ptr = nullptr;
+  int8_t* aux_input_to_cell_weights_ptr = nullptr;
+  int8_t* aux_input_to_output_weights_ptr = nullptr;
+  float aux_input_to_input_weights_scale = 0.0f;
+  float aux_input_to_forget_weights_scale = 0.0f;
+  float aux_input_to_cell_weights_scale = 0.0f;
+  float aux_input_to_output_weights_scale = 0.0f;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    aux_input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+    aux_input_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+    aux_input_to_cell_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+    aux_input_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+    aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+    aux_input_to_forget_weights_scale =
+        aux_input_to_forget_weights->params.scale;
+    aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+    aux_input_to_output_weights_scale =
+        aux_input_to_output_weights->params.scale;
+  }
+  if (forward_sequence) {
+    // Feed the sequence into the LSTM step-by-step.
+    for (int t = 0; t < max_time; t++) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          aux_input_ptr, aux_input_to_input_weights_ptr,
+          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+          cell_to_output_weights_scale, input_gate_bias_ptr,
+          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          input_gate_scratch, forget_gate_scratch, cell_scratch,
+          output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+          recovered_cell_weights_ptr, quantized_input_ptr,
+          quantized_aux_input_ptr, quantized_output_state_ptr,
+          quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+          output_ptr);
+    }
+  } else {
+    // Loop through the sequence backwards.
+    for (int t = max_time - 1; t >= 0; t--) {
+      const float* input_ptr = input->data.f + t * n_batch * n_input;
+      float* output_ptr = output->data.f + t * n_batch * n_output;
+
+      kernel_utils::LstmStepWithAuxInput(
+          input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+          input_to_forget_weights_ptr, input_to_forget_weights_scale,
+          input_to_cell_weights_ptr, input_to_cell_weights_scale,
+          input_to_output_weights_ptr, input_to_output_weights_scale,
+          aux_input_ptr, aux_input_to_input_weights_ptr,
+          aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+          aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+          aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+          aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+          recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+          recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+          recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+          recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+          cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+          cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+          cell_to_output_weights_scale, input_gate_bias_ptr,
+          forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+          projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+          params, n_batch, n_cell, n_input, aux_input_size, n_output,
+          input_gate_scratch, forget_gate_scratch, cell_scratch,
+          output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+          recovered_cell_weights_ptr, quantized_input_ptr,
+          quantized_aux_input_ptr, quantized_output_state_ptr,
+          quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+          output_ptr);
+    }
+  }
+
   return kTfLiteOk;
 }
 
 // The LSTM Op engine.
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
 
   // Input tensor.
   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const int max_time = input->dims->data[0];
-  const int n_batch = input->dims->data[1];
-  const int n_input = input->dims->data[2];
 
   // Tensors for the forward cell.
   const TfLiteTensor* fw_input_to_input_weights =
@@ -553,155 +1155,134 @@
   const TfLiteTensor* bw_projection_bias =
       GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);
 
+  // State tensors.
   TfLiteTensor* bw_activation_state =
       GetVariableInput(context, node, kBwInputActivationStateTensor);
   TfLiteTensor* bw_cell_state =
       GetVariableInput(context, node, kBwInputCellStateTensor);
   TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
 
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
-  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
-  // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
-  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
-  const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
-
-  // Index the scratch buffers pointers to the global scratch buffer.
+  // Temporary tensors.
   TfLiteTensor* fw_scratch_buffer =
-      &context->tensors[node->temporaries->data[0]];
-  float* fw_input_gate_scratch = nullptr;
-  float* fw_cell_scratch = nullptr;
-  float* fw_forget_gate_scratch = nullptr;
-  float* fw_output_gate_scratch = nullptr;
-  if (fw_use_cifg) {
-    fw_cell_scratch = fw_scratch_buffer->data.f;
-    fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
-    fw_output_gate_scratch =
-        fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
-  } else {
-    fw_input_gate_scratch = fw_scratch_buffer->data.f;
-    fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
-    fw_forget_gate_scratch =
-        fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
-    fw_output_gate_scratch =
-        fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
-  }
-
-  // Check optional tensors, the respective pointers can be null.
-  const float* fw_input_to_input_weights_ptr =
-      (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
-  const float* fw_recurrent_to_input_weights_ptr =
-      (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
-  const float* fw_input_gate_bias_ptr =
-      (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
-  const float* fw_cell_to_input_weights_ptr =
-      (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
-                                        : nullptr;
-  const float* fw_cell_to_forget_weights_ptr =
-      (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
-  const float* fw_cell_to_output_weights_ptr =
-      (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
-  const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
-                                               ? nullptr
-                                               : fw_projection_weights->data.f;
-  const float* fw_projection_bias_ptr =
-      (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
-  // Loop through the sequence.
-  for (int t = 0; t < max_time; t++) {
-    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
-    float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
-    kernel_utils::LstmStep(
-        input_ptr_batch, fw_input_to_input_weights_ptr,
-        fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
-        fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
-        fw_recurrent_to_forget_weights->data.f,
-        fw_recurrent_to_cell_weights->data.f,
-        fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
-        fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
-        fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
-        fw_cell_bias->data.f, fw_output_gate_bias->data.f,
-        fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
-        n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f,
-        fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
-        fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
-  }
-
-  // n_cell and n_output will be the same size when there is no projection.
-  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
-  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
-  // Since we have already checked that weights are all there or none, we can
-  // check the existense of only one to the get the condition.
-  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
-  const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
-  // Index the scratch buffers pointers to the global scratch buffer.
+      GetTemporary(context, node, kFwScratchBuffer);
   TfLiteTensor* bw_scratch_buffer =
-      &context->tensors[node->temporaries->data[1]];
-  float* bw_input_gate_scratch = nullptr;
-  float* bw_cell_scratch = nullptr;
-  float* bw_forget_gate_scratch = nullptr;
-  float* bw_output_gate_scratch = nullptr;
-  if (bw_use_cifg) {
-    bw_cell_scratch = bw_scratch_buffer->data.f;
-    bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
-    bw_output_gate_scratch =
-        bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
-  } else {
-    bw_input_gate_scratch = bw_scratch_buffer->data.f;
-    bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
-    bw_forget_gate_scratch =
-        bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
-    bw_output_gate_scratch =
-        bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
+      GetTemporary(context, node, kBwScratchBuffer);
+
+  // (Optional) auxiliary inputs.
+  const TfLiteTensor* aux_input =
+      GetOptionalInputTensor(context, node, kAuxInputTensor);
+  const TfLiteTensor* fw_aux_input_to_input_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_forget_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_cell_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+  const TfLiteTensor* fw_aux_input_to_output_weights =
+      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_input_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_forget_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_cell_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+  const TfLiteTensor* bw_aux_input_to_output_weights =
+      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+  switch (fw_input_to_output_weights->type) {
+    case kTfLiteFloat32: {
+      TfLiteStatus fw_pass_status = EvalFloat(
+          input, fw_input_to_input_weights, fw_input_to_forget_weights,
+          fw_input_to_cell_weights, fw_input_to_output_weights,
+          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+          fw_cell_to_input_weights, fw_cell_to_forget_weights,
+          fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+          fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+          fw_aux_input_to_output_weights, fw_input_gate_bias,
+          fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+          fw_projection_weights, fw_projection_bias, params,
+          /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
+          fw_cell_state, fw_output);
+      TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+      TfLiteStatus bw_pass_status = EvalFloat(
+          input, bw_input_to_input_weights, bw_input_to_forget_weights,
+          bw_input_to_cell_weights, bw_input_to_output_weights,
+          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+          bw_cell_to_input_weights, bw_cell_to_forget_weights,
+          bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+          bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+          bw_aux_input_to_output_weights, bw_input_gate_bias,
+          bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+          bw_projection_weights, bw_projection_bias, params,
+          /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
+          bw_cell_state, bw_output);
+      TF_LITE_ENSURE_OK(context, bw_pass_status);
+      return kTfLiteOk;
+    }
+    case kTfLiteUInt8: {
+      TfLiteTensor* input_quantized =
+          GetTemporary(context, node, kInputQuantized);
+      TfLiteTensor* aux_input_quantized =
+          GetTemporary(context, node, kAuxInputQuantized);
+      TfLiteTensor* fw_activation_state_quantized =
+          GetTemporary(context, node, kFwActivationStateQuantized);
+      TfLiteTensor* bw_activation_state_quantized =
+          GetTemporary(context, node, kBwActivationStateQuantized);
+      TfLiteTensor* fw_cell_state_quantized =
+          GetTemporary(context, node, kFwCellStateQuantized);
+      TfLiteTensor* bw_cell_state_quantized =
+          GetTemporary(context, node, kBwCellStateQuantized);
+      TfLiteTensor* scaling_factors =
+          GetTemporary(context, node, kScalingFactors);
+      TfLiteTensor* prod_scaling_factors =
+          GetTemporary(context, node, kProductScalingFactors);
+      TfLiteTensor* recovered_cell_weights =
+          GetTemporary(context, node, kRecoveredCellWeights);
+
+      TfLiteStatus fw_pass_status = EvalHybrid(
+          input, fw_input_to_input_weights, fw_input_to_forget_weights,
+          fw_input_to_cell_weights, fw_input_to_output_weights,
+          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+          fw_cell_to_input_weights, fw_cell_to_forget_weights,
+          fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+          fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+          fw_aux_input_to_output_weights, fw_input_gate_bias,
+          fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+          fw_projection_weights, fw_projection_bias, params,
+          /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
+          prod_scaling_factors, recovered_cell_weights, input_quantized,
+          aux_input_quantized, fw_activation_state_quantized,
+          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
+          fw_output);
+      TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+      TfLiteStatus bw_pass_status = EvalHybrid(
+          input, bw_input_to_input_weights, bw_input_to_forget_weights,
+          bw_input_to_cell_weights, bw_input_to_output_weights,
+          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+          bw_cell_to_input_weights, bw_cell_to_forget_weights,
+          bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+          bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+          bw_aux_input_to_output_weights, bw_input_gate_bias,
+          bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+          bw_projection_weights, bw_projection_bias, params,
+          /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
+          prod_scaling_factors, recovered_cell_weights, input_quantized,
+          aux_input_quantized, bw_activation_state_quantized,
+          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
+          bw_output);
+      TF_LITE_ENSURE_OK(context, bw_pass_status);
+      return kTfLiteOk;
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           fw_input_to_output_weights->type);
+      return kTfLiteError;
   }
-
-  // Check optional tensors, the respective pointers can be null.
-  const float* bw_input_to_input_weights_ptr =
-      (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
-  const float* bw_recurrent_to_input_weights_ptr =
-      (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
-  const float* bw_input_gate_bias_ptr =
-      (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
-  const float* bw_cell_to_input_weights_ptr =
-      (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
-                                        : nullptr;
-  const float* bw_cell_to_forget_weights_ptr =
-      (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
-  const float* bw_cell_to_output_weights_ptr =
-      (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
-  const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
-                                               ? nullptr
-                                               : bw_projection_weights->data.f;
-  const float* bw_projection_bias_ptr =
-      (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
-  // Loop through the sequence backwards.
-  for (int t = max_time - 1; t >= 0; t--) {
-    const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
-    float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
-    kernel_utils::LstmStep(
-        input_ptr_batch, bw_input_to_input_weights_ptr,
-        bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
-        bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
-        bw_recurrent_to_forget_weights->data.f,
-        bw_recurrent_to_cell_weights->data.f,
-        bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
-        bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
-        bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
-        bw_cell_bias->data.f, bw_output_gate_bias->data.f,
-        bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
-        n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f,
-        bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
-        bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
-  }
-
-  // Backward step.
   return kTfLiteOk;
 }
 
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index d058fab..74ba802 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -177,6 +177,16 @@
 
     bw_output_ = AddOutput(TensorType_FLOAT32);
 
+    aux_input_ = AddNullInput();
+    fw_aux_input_to_input_weights_ = AddNullInput();
+    fw_aux_input_to_forget_weights_ = AddNullInput();
+    fw_aux_input_to_cell_weights_ = AddNullInput();
+    fw_aux_input_to_output_weights_ = AddNullInput();
+    bw_aux_input_to_input_weights_ = AddNullInput();
+    bw_aux_input_to_forget_weights_ = AddNullInput();
+    bw_aux_input_to_cell_weights_ = AddNullInput();
+    bw_aux_input_to_output_weights_ = AddNullInput();
+
     SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
                  BuiltinOptions_LSTMOptions,
                  CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
@@ -340,6 +350,16 @@
   int fw_output_;
   int bw_output_;
 
+  int aux_input_;
+  int fw_aux_input_to_input_weights_;
+  int fw_aux_input_to_forget_weights_;
+  int fw_aux_input_to_cell_weights_;
+  int fw_aux_input_to_output_weights_;
+  int bw_aux_input_to_input_weights_;
+  int bw_aux_input_to_forget_weights_;
+  int bw_aux_input_to_cell_weights_;
+  int bw_aux_input_to_output_weights_;
+
   int n_batch_;
   int n_input_;
   int n_fw_cell_;
@@ -415,6 +435,16 @@
 
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
+
+          {n_batch, sequence_length, 0},  // aux_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_forget tensor
+          {n_cell, 0},                    // aux_fw_input_to_cell tensor
+          {n_cell, 0},                    // aux_fw_input_to_output tensor
+          {n_cell, 0},                    // aux_bw_input_to_input tensor
+          {n_cell, 0},                    // aux_bw_input_to_forget tensor
+          {n_cell, 0},                    // aux_bw_input_to_cell tensor
+          {n_cell, 0},                    // aux_bw_input_to_output tensor
       });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -562,6 +592,16 @@
 
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
+
+          {n_batch, sequence_length, 0},  // aux_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_forget tensor
+          {n_cell, 0},                    // aux_fw_input_to_cell tensor
+          {n_cell, 0},                    // aux_fw_input_to_output tensor
+          {n_cell, 0},                    // aux_bw_input_to_input tensor
+          {n_cell, 0},                    // aux_bw_input_to_forget tensor
+          {n_cell, 0},                    // aux_bw_input_to_cell tensor
+          {n_cell, 0},                    // aux_bw_input_to_output tensor
       });
 
   lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -709,6 +749,16 @@
 
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
+
+          {n_batch, sequence_length, 0},  // aux_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_forget tensor
+          {n_cell, 0},                    // aux_fw_input_to_cell tensor
+          {n_cell, 0},                    // aux_fw_input_to_output tensor
+          {n_cell, 0},                    // aux_bw_input_to_input tensor
+          {n_cell, 0},                    // aux_bw_input_to_forget tensor
+          {n_cell, 0},                    // aux_bw_input_to_cell tensor
+          {n_cell, 0},                    // aux_bw_input_to_output tensor
       });
 
   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -848,6 +898,16 @@
 
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
+
+          {n_batch, sequence_length, 0},  // aux_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_forget tensor
+          {n_cell, 0},                    // aux_fw_input_to_cell tensor
+          {n_cell, 0},                    // aux_fw_input_to_output tensor
+          {n_cell, 0},                    // aux_bw_input_to_input tensor
+          {n_cell, 0},                    // aux_bw_input_to_forget tensor
+          {n_cell, 0},                    // aux_bw_input_to_cell tensor
+          {n_cell, 0},                    // aux_bw_input_to_output tensor
       });
 
   lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -987,6 +1047,16 @@
 
           {n_batch, n_output},  // activation_state tensor
           {n_batch, n_cell},    // cell_state tensor
+
+          {n_batch, sequence_length, 0},  // aux_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_input tensor
+          {n_cell, 0},                    // aux_fw_input_to_forget tensor
+          {n_cell, 0},                    // aux_fw_input_to_cell tensor
+          {n_cell, 0},                    // aux_fw_input_to_output tensor
+          {n_cell, 0},                    // aux_bw_input_to_input tensor
+          {n_cell, 0},                    // aux_bw_input_to_forget tensor
+          {n_cell, 0},                    // aux_bw_input_to_cell tensor
+          {n_cell, 0},                    // aux_bw_input_to_output tensor
       });
 
   lstm.SetInputToInputWeights(
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index d988ef8..2f896c5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 8dd48af..a797214 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -15,8 +15,8 @@
 #include <string.h>
 #include <algorithm>
 #include <complex>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 8b4d778..4cd9634 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 605a20a..25ea556 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 3ed0cdb..ab6bdae 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -20,8 +20,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/eigen_support.h"
 #include "tensorflow/contrib/lite/kernels/gemm_support.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 411615a..f7e6f08 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,30 @@
                              }));
 }
 
+TEST_P(ConvolutionOpTest, InputAndFilterSameWidthHeight) {
+  ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+                       {TensorType_FLOAT32, {1, 2, 4, 1}},
+                       {TensorType_FLOAT32, {}});
+
+  m.SetInput({
+      // First batch
+      1, 1, 1, 1,  // row = 1
+      2, 2, 2, 2,  // row = 2
+      // Second batch
+      1, 2, 3, 4,  // row = 1
+      1, 2, 3, 4,  // row = 2
+  });
+  m.SetFilter({
+      1, 2, 3, 4,    // row = 1
+      -1, -1, 1, 1,  // row = 2
+  });
+  m.SetBias({0});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 34}));
+}
+
 TEST_P(ConvolutionOpTest, PointwiseFloat32) {
   ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
                        {TensorType_FLOAT32, {1, 1, 1, 2}},
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 2151815..3e1ce60 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
@@ -126,23 +126,28 @@
 
   // Matching GetWindowedOutputSize in TensorFlow.
   auto padding = params->padding;
-  auto compute_out_size = [padding](int imageSize, int filterSize,
-                                    int stride) -> int {
+  auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+                                    int dilation_rate) -> int {
+    int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
     return padding == kTfLitePaddingSame
-               ? (imageSize + stride - 1) / stride
+               ? (image_size + stride - 1) / stride
                : padding == kTfLitePaddingValid
-                     ? (imageSize - filterSize + stride) / stride
+                     ? (image_size - effective_filter_size + stride) / stride
                      : 0;
   };
 
-  int out_width = compute_out_size(width, filter_width, params->stride_width);
+  int out_width = compute_out_size(width, filter_width, params->stride_width,
+                                   params->dilation_width_factor);
   int out_height =
-      compute_out_size(height, filter_height, params->stride_height);
+      compute_out_size(height, filter_height, params->stride_height,
+                       params->dilation_height_factor);
 
-  data->padding.height = ComputePadding(params->stride_height, 1, height,
-                                        filter_height, out_height);
+  data->padding.height =
+      ComputePadding(params->stride_height, params->dilation_height_factor,
+                     height, filter_height, out_height);
   data->padding.width =
-      ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+      ComputePadding(params->stride_width, params->dilation_width_factor, width,
+                     filter_width, out_width);
 
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
@@ -177,8 +182,19 @@
 
   void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
                          const Dims<4>&, const float*, const Dims<4>&, int, int,
-                         int, int, int, float, float, float*, const Dims<4>&);
-  if (kernel_type == kReference) {
+                         int, int, int, int, int, float, float, float*,
+                         const Dims<4>&);
+  KernelType effective_kernel_type;
+  // TODO(suharshs): Currently only the reference implementation supports
+  // dilations.
+  if ((params->dilation_width_factor != 1) ||
+      (params->dilation_height_factor != 1)) {
+    effective_kernel_type = kReference;
+  } else {
+    effective_kernel_type = kernel_type;
+  }
+
+  if (effective_kernel_type == kReference) {
     depthwise_conv = &reference_ops::DepthwiseConv;
   } else {
     depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -188,7 +204,8 @@
       GetTensorData<float>(input), GetTensorDims(input),
       GetTensorData<float>(filter), GetTensorDims(filter),
       GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
-      params->stride_height, data->padding.width, data->padding.height,
+      params->stride_height, params->dilation_width_factor,
+      params->dilation_height_factor, data->padding.width, data->padding.height,
       params->depth_multiplier, output_activation_min, output_activation_max,
       GetTensorData<float>(output), GetTensorDims(output));
 }
@@ -204,9 +221,20 @@
 
   void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
                          const Dims<4>&, int32, const int32*, const Dims<4>&,
-                         int, int, int, int, int, int32, int32, int, int32,
-                         int32, uint8*, const Dims<4>&);
-  if (kernel_type == kReference) {
+                         int, int, int, int, int, int, int, int32, int32, int,
+                         int32, int32, uint8*, const Dims<4>&);
+
+  KernelType effective_kernel_type;
+  // TODO(suharshs): Currently only the reference implementation supports
+  // dilations.
+  if ((params->dilation_width_factor != 1) ||
+      (params->dilation_height_factor != 1)) {
+    effective_kernel_type = kReference;
+  } else {
+    effective_kernel_type = kernel_type;
+  }
+
+  if (effective_kernel_type == kReference) {
     depthwise_conv = &reference_ops::DepthwiseConv;
   } else {
     depthwise_conv = &optimized_ops::DepthwiseConv;
@@ -216,7 +244,8 @@
       GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
       GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
       GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
-      params->stride_height, data->padding.width, data->padding.height,
+      params->stride_height, params->dilation_width_factor,
+      params->dilation_height_factor, data->padding.width, data->padding.height,
       params->depth_multiplier, output_offset, data->output_multiplier,
       data->output_shift, data->output_activation_min,
       data->output_activation_max, GetTensorData<uint8_t>(output),
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index c00cafb..2af26ab 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -30,7 +30,8 @@
   // stride values.
   BaseDepthwiseConvolutionOpModel(const TensorData& input,
                                   const TensorData& filter,
-                                  const TensorData& output) {
+                                  const TensorData& output,
+                                  int dilation_factor = 1) {
     input_ = AddInput(input);
     filter_ = AddInput(filter);
 
@@ -56,7 +57,8 @@
         BuiltinOperator_DEPTHWISE_CONV_2D,
         BuiltinOptions_DepthwiseConv2DOptions,
         CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
-                                     ActivationFunctionType_NONE)
+                                     ActivationFunctionType_NONE,
+                                     dilation_factor, dilation_factor)
             .Union());
 
     BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
@@ -110,6 +112,58 @@
                              }));
 }
 
+TEST(DepthwiseConvolutionOpTest, SimpleDilatedTest) {
+  const int depth = 1;
+  const int image_width = 9;
+  const int image_height = 9;
+  const int image_batch_count = 1;
+  const int filter_size = 3;
+  const int filter_count = 1;
+  const int dilation_factor = 3;
+  DepthwiseConvolutionOpModel m(
+      {TensorType_FLOAT32,
+       {image_batch_count, image_height, image_width, depth}},
+      {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+      {TensorType_FLOAT32, {}}, dilation_factor);
+
+  // The image matrix is:
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  // clang-format on
+  // The filter matrix is:
+  // | 1 | 2 | 3 |
+  // | 4 | 5 | 6 |
+  // | 7 | 8 | 9 |
+  m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+  // No bias for this test.
+  m.SetBias({0});
+  m.Invoke();
+
+  // Since the dilation rate is 3 this will reduce the size of the output from
+  // 7x7 to 3x3 of all 5s. Specifically:
+  // | 5 | 5 | 5 |
+  // | 5 | 5 | 5 |
+  // | 5 | 5 | 5 |
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
 class QuantizedDepthwiseConvolutionOpModel
     : public BaseDepthwiseConvolutionOpModel {
  public:
@@ -207,6 +261,64 @@
               ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
 }
 
+TEST(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTest) {
+  const int depth = 1;
+  const int image_width = 9;
+  const int image_height = 9;
+  const int image_batch_count = 1;
+  const int filter_size = 3;
+  const int filter_count = 1;
+  const int dilation_factor = 3;
+  QuantizedDepthwiseConvolutionOpModel m(
+      {TensorType_UINT8,
+       {image_batch_count, image_height, image_width, depth},
+       0,
+       255},
+      {TensorType_UINT8,
+       {depth, filter_size, filter_size, filter_count},
+       0,
+       255},
+      {TensorType_UINT8, {}, 0, 255}, dilation_factor);
+
+  // The image matrix is:
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+  // clang-format off
+  m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 1, 1, 1, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0,
+              0, 0, 0, 0, 0, 0, 0, 0, 0});
+  // clang-format on
+  // The filter matrix is:
+  // | 1 | 2 | 3 |
+  // | 4 | 5 | 6 |
+  // | 7 | 8 | 9 |
+  m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+  // No bias for this test.
+  m.SetBias({0});
+  m.Invoke();
+
+  // Since the dilation rate is 3 this will reduce the size of the output from
+  // 7x7 to 3x3 of all 5s. Specifically:
+  // | 5 | 5 | 5 |
+  // | 5 | 5 | 5 |
+  // | 5 | 5 | 5 |
+  EXPECT_THAT(m.GetDequantizedOutput(),
+              ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
 }  // namespace
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 2b0f044..3a08f48 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@
 #include <string.h>
 #include <vector>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index 136697f..d290663 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -16,8 +16,8 @@
 #include <numeric>
 #include <vector>
 #include "flatbuffers/flexbuffers.h"  // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d7420dd..7945c09 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index ec77856..feb1543 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,10 +15,10 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace EigenForTFLite {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
 }
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index e19779e..8c624b3 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 
 #include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 
@@ -90,6 +90,10 @@
   return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
 }
 
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
   return EvalLogical(context, node, [](bool v) { return !v; });
 }
@@ -129,6 +133,14 @@
   return &r;
 }
 
+TfLiteRegistration* Register_SQUARE() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SquareEval};
+  return &r;
+}
+
 TfLiteRegistration* Register_LOGICAL_NOT() {
   static TfLiteRegistration r = {
       /*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index b9d7d73..5dd89a0 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -92,6 +92,15 @@
   EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
 }
 
+TEST(ElementWise, Square) {
+  ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
+  m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<float>(m.output()),
+              ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
 TEST(ElementWise, LogicalNot) {
   ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
   m.PopulateTensor<bool>(m.input(), {true, false, true, false});
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index b2dff87..fe33f98 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -37,8 +37,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be369..aa75b03 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@
 #include <algorithm>
 #include <cmath>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdf..673e7be 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
index ed33012..fa1140b 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -15,8 +15,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
index 50dc860..a3bc181 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -14,7 +14,7 @@
 limitations under the License.
 ==============================================================================*/
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index 0ef1a50..f9bc374 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index f7d5f51..59ff77f 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
index 75cf19a..5d62cd2 100644
--- a/tensorflow/contrib/lite/kernels/floor_div.cc
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index eaf5a67..7a71fcc 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -20,8 +20,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/gemm_support.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 2b2a9e6..badd2de 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 1d42929..1b48884 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772..43cd2b3 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
 
 #include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 namespace gemm_support {
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index f37c66a..c0b3c3c 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -39,8 +39,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 #include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 464163b..a6fd4ac 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -163,7 +163,7 @@
         ":tensor_utils",
         "//third_party/eigen3",
         "@gemmlowp",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ] + select({
         ":haswell": tflite_deps_intel,
         ":ios_x86_64": tflite_deps_intel,
@@ -198,7 +198,7 @@
         ":round",
         "//third_party/eigen3",
         "@gemmlowp",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ] + select({
         ":haswell": tflite_deps_intel,
         ":ios_x86_64": tflite_deps_intel,
@@ -220,13 +220,15 @@
         "optimized/eigen_spatial_convolutions.h",
         "optimized/eigen_tensor_reduced_instantiations_oss.h",
         "optimized/multithreaded_conv.h",
+        # FIXME(petewarden) - This should be removed, since it's a header from the
+        # :tensor dependency below.
         "tensor.h",
     ],
     deps = [
         ":optimized_base",
+        ":tensor",
         ":types",
-        "//tensorflow/contrib/lite:builtin_op_data",
-        "//tensorflow/contrib/lite:context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//third_party/eigen3",
     ],
 )
@@ -236,7 +238,7 @@
     srcs = ["tensor_test.cc"],
     tags = ["no_oss"],
     deps = [
-        ":reference",
+        ":tensor",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -296,7 +298,7 @@
         ":strided_slice_logic",
         ":types",
         "@gemmlowp",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ] + select({
         ":haswell": tflite_deps_intel,
         ":ios_x86_64": tflite_deps_intel,
@@ -326,7 +328,7 @@
         ":strided_slice_logic",
         ":types",
         "@gemmlowp",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ] + select({
         ":haswell": tflite_deps_intel,
         ":ios_x86_64": tflite_deps_intel,
@@ -341,11 +343,27 @@
 )
 
 cc_library(
-    name = "reference",
-    hdrs = ["tensor.h"],
+    name = "tensor",
+    hdrs = [
+        "tensor.h",
+        "tensor_ctypes.h",
+    ],
     deps = [
         ":types",
-        "//tensorflow/contrib/lite:context",
+        "//tensorflow/contrib/lite/c:c_api_internal",
+    ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
+    name = "reference",
+    hdrs = [
+        "tensor.h",
+        "tensor_ctypes.h",
+    ],
+    deps = [
+        ":types",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ],
 )
 
@@ -359,7 +377,7 @@
     ],
     deps = [
         ":round",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:activation_functor",
         "//tensorflow/contrib/lite/kernels:op_macros",
     ],
@@ -384,7 +402,7 @@
         ":cpu_check",
         ":round",
         ":types",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:activation_functor",
         "//tensorflow/contrib/lite/kernels:op_macros",
         "@arm_neon_2_x86_sse",
@@ -398,7 +416,7 @@
     hdrs = ["kernel_utils.h"],
     deps = [
         ":tensor_utils",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
     ],
 )
 
@@ -441,7 +459,7 @@
     copts = NEON_FLAGS_IF_APPLICABLE,
     deps = [
         "//tensorflow/contrib/lite/kernels:activation_functor",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "@arm_neon_2_x86_sse",
         "@gemmlowp",
     ] + select({
@@ -517,7 +535,7 @@
     ],
     deps = [
         ":tensor_utils",
-        "//tensorflow/contrib/lite:builtin_op_data",
+        "//tensorflow/contrib/lite/c:c_api_internal",
         "//tensorflow/contrib/lite/kernels:test_util",
         "@com_google_googletest//:gtest_main",
     ],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index 4a29aec..aaa3c05 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -47,7 +47,7 @@
 #endif
 #endif
 
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 360b472..56e9367 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@
 ==============================================================================*/
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 
-#include <algorithm>
-
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 
 namespace tflite {
@@ -203,9 +201,9 @@
       cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
       cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
       cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
-      projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
-      output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
-      cell_scratch, output_gate_scratch, output_ptr_batch);
+      projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
+      n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
+      forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
 }
 
 void LstmStepWithAuxInput(
@@ -227,8 +225,8 @@
     const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
     const float* output_gate_bias_ptr, const float* projection_weights_ptr,
     const float* projection_bias_ptr, const TfLiteLSTMParams* params,
-    int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
-    float* cell_state_ptr, float* input_gate_scratch,
+    int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch) {
   // Since we have already checked that weights are all there or none, we can
@@ -268,19 +266,20 @@
   if (aux_input_ptr_batch != nullptr) {
     if (!use_cifg) {
       tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-          aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
-          n_batch, input_gate_scratch, /*result_stride=*/1);
+          aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+          aux_input_ptr_batch, n_batch, input_gate_scratch,
+          /*result_stride=*/1);
     }
 
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
-        n_batch, forget_gate_scratch, /*result_stride=*/1);
+        aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+        aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
+        aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
         n_batch, cell_scratch, /*result_stride=*/1);
     tensor_utils::MatrixBatchVectorMultiplyAccumulate(
-        aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch,
-        n_batch, output_gate_scratch, /*result_stride=*/1);
+        aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+        aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
   }
 
   // For each batch and cell: compute recurrent_weight * output_state.
@@ -432,10 +431,11 @@
       cell_to_output_weights_ptr, cell_to_output_weights_scale,
       input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
       output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
-      projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
-      input_gate_scratch, forget_gate_scratch, cell_scratch,
-      output_gate_scratch, scaling_factors, product_scaling_factors,
-      recovered_cell_weights, quantized_input_ptr_batch,
+      projection_bias_ptr, params, n_batch, n_cell, n_input,
+      /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
+      cell_scratch, output_gate_scratch, scaling_factors,
+      product_scaling_factors, recovered_cell_weights,
+      quantized_input_ptr_batch,
       /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
       quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
       output_ptr_batch);
@@ -476,8 +476,9 @@
         const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
         float projection_weights_scale, const float* projection_bias_ptr,
         const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
-        int n_output, float* input_gate_scratch, float* forget_gate_scratch,
-        float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+        int n_aux_input, int n_output, float* input_gate_scratch,
+        float* forget_gate_scratch, float* cell_scratch,
+        float* output_gate_scratch, float* scaling_factors,
         float* product_scaling_factors, float* recovered_cell_weights,
         int8_t* quantized_input_ptr_batch,
         int8_t* quantized_aux_input_ptr_batch,
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 38436c1..b5558cc 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 namespace tflite {
 namespace kernel_utils {
@@ -131,8 +131,8 @@
     const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
     const float* output_gate_bias_ptr, const float* projection_weights_ptr,
     const float* projection_bias_ptr, const TfLiteLSTMParams* params,
-    int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
-    float* cell_state_ptr, float* input_gate_scratch,
+    int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+    float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
     float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
     float* output_ptr_batch);
 
@@ -252,12 +252,13 @@
     const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
     float projection_weights_scale, const float* projection_bias_ptr,
     const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
-    int n_output, float* input_gate_scratch, float* forget_gate_scratch,
-    float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
-    float* product_scaling_factors, float* recovered_cell_weights,
-    int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch,
-    int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
-    float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch);
+    int n_aux_input, int n_output, float* input_gate_scratch,
+    float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+    float* scaling_factors, float* product_scaling_factors,
+    float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+    int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+    int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+    float* cell_state_ptr, float* output_ptr_batch);
 
 }  // namespace kernel_utils
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2..70810ca 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -1067,6 +1067,26 @@
   }
 }
 
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+                          const float* filter_data, const Dims<4>& filter_dims,
+                          const float* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
+                          float output_activation_min,
+                          float output_activation_max, float* output_data,
+                          const Dims<4>& output_dims) {
+  // TODO(suharshs): Optimized implementation of dilation depthwise conv needs
+  // to be implemented.
+  TFLITE_DCHECK(dilation_width_factor == 1);
+  TFLITE_DCHECK(dilation_height_factor == 1);
+
+  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+                bias_dims, stride_width, stride_height, pad_width, pad_height,
+                depth_multiplier, output_activation_min, output_activation_max,
+                output_data, output_dims);
+}
+
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 3fd00c8..f707279 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1964,6 +1964,30 @@
   }
 }
 
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                          int32 input_offset, const uint8* filter_data,
+                          const Dims<4>& filter_dims, int32 filter_offset,
+                          const int32* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
+                          int32 output_offset, int32 output_multiplier,
+                          int output_shift, int32 output_activation_min,
+                          int32 output_activation_max, uint8* output_data,
+                          const Dims<4>& output_dims) {
+  // TODO(suharshs): Optimized implementation of dilation depthwise conv is not
+  // supported yet.
+  TFLITE_DCHECK(dilation_width_factor == 1);
+  TFLITE_DCHECK(dilation_height_factor == 1);
+
+  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+                filter_offset, bias_data, bias_dims, stride_width,
+                stride_height, pad_width, pad_height, depth_multiplier,
+                output_offset, output_multiplier, output_shift,
+                output_activation_min, output_activation_max, output_data,
+                output_dims);
+}
+
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 921aae1..59f0e3c 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -26,7 +26,7 @@
 #include <tuple>
 #include <type_traits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -113,8 +113,8 @@
           filter_width * filter_height * input_depth;
       Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
       dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
-      EigenMatrix output(output_data, 1, filter_count);
-      ConstEigenMatrix input(input_data, 1, k);
+      EigenMatrix output(output_data, input_batches, filter_count);
+      ConstEigenMatrix input(input_data, input_batches, k);
       ConstEigenMatrix filter(filter_data, k, filter_count);
       MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
                                                       filter, dim_pair);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 70b6994..2741817 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@
 #include <stdlib.h>
 #include <string.h>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/common.h"
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index e671624..630a6bb 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@
 
 // TODO(ghodrat): Remove this header file and the dependency to internal data
 // structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
 
@@ -79,6 +79,11 @@
                    n_batch, result, result_stride);
 }
 
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                          float* batch_vector) {
+  PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
 void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
                              float* batch_vector) {
   PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -138,6 +143,13 @@
                    reduction_size);
 }
 
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+                             int v_size, int n_batch,
+                             float normalization_epsilon) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+                                  normalization_epsilon);
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 4763d77..aaf93ae 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -43,6 +43,14 @@
 // Unoptimized reference ops:
 using reference_ops::ArgMax;
 using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
 using reference_ops::BroadcastAdd4DSlow;
 using reference_ops::BroadcastGreater;
 using reference_ops::BroadcastGreaterEqual;
@@ -58,8 +66,12 @@
 using reference_ops::Gather;
 using reference_ops::Greater;
 using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
 using reference_ops::Less;
 using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
 using reference_ops::Mean;
 using reference_ops::RankOneSelect;
 using reference_ops::Relu1;
@@ -188,6 +200,8 @@
       UnalignedConstMatrix;
 };
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 // TODO(b/62193649): this function is only needed as long
 // as we have the --variable_batch hack.
 template <typename Scalar, int N>
@@ -200,6 +214,18 @@
   return MatrixMap<Scalar>(data, rows, cols);
 }
 
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+                                                   const RuntimeShape& shape,
+                                                   int rows) {
+  const int flatsize = shape.FlatSize();
+  TFLITE_DCHECK_EQ(flatsize % rows, 0);
+  const int cols = flatsize / rows;
+  return MatrixMap<Scalar>(data, rows, cols);
+}
+
 // This is like the template-parameter version, except that the power-of-two is
 // passed as a function parameter. The template version is to be preferred,
 // since some target hardware optimizations depend on the range of the exponent.
@@ -248,16 +274,16 @@
   return true;
 }
 
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
-                                             const Dims<4>& bias_dims,
-                                             float* array_data,
-                                             const Dims<4>& array_dims,
-                                             float output_activation_min,
-                                             float output_activation_max) {
+inline void AddBiasAndEvalActivationFunction(float output_activation_min,
+                                             float output_activation_max,
+                                             const RuntimeShape& bias_shape,
+                                             const float* bias_data,
+                                             const RuntimeShape& array_shape,
+                                             float* array_data) {
 #ifdef USE_NEON
   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
-  const int bias_size = FlatSize(bias_dims);
-  const int array_size = FlatSize(array_dims);
+  const int bias_size = bias_shape.FlatSize();
+  const int array_size = array_shape.FlatSize();
   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
   float* array_ptr = array_data;
   float* array_end_ptr = array_ptr + array_size;
@@ -307,8 +333,8 @@
   }
 #else  // not NEON
   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
-  const int bias_size = FlatSize(bias_dims);
-  const int array_size = FlatSize(array_dims);
+  const int bias_size = bias_shape.FlatSize();
+  const int array_size = array_shape.FlatSize();
   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
   for (int array_offset = 0; array_offset < array_size;
        array_offset += bias_size) {
@@ -321,6 +347,19 @@
 #endif
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+                                             const Dims<4>& bias_dims,
+                                             float* array_data,
+                                             const Dims<4>& array_dims,
+                                             float output_activation_min,
+                                             float output_activation_max) {
+  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+                                   DimsToShape(bias_dims), bias_data,
+                                   DimsToShape(array_dims), array_data);
+}
+
 // Note: This to be converted to RuntimeShapes along with Conv.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
@@ -368,21 +407,24 @@
 // to a matrix*vector product. LSTM cells contain a fully-connected node;
 // when quantized, this becomes a special type of GEMV operation where
 // the output is 16bit-quantized, thus needs its own special path.
-inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
-                            const uint8* weights_data,
-                            const Dims<4>& weights_dims,
-                            uint8 weights_zero_point, const int32* bias_data,
-                            const Dims<4>& bias_dims, int32 accum_multiplier,
-                            int accum_shift, int16* output_data,
-                            const Dims<4>& output_dims) {
+inline void GEMVForLstmCell(const RuntimeShape& input_shape,
+                            const uint8* input_data,
+                            const RuntimeShape& weights_shape,
+                            const uint8* weights_data, uint8 weights_zero_point,
+                            const RuntimeShape& bias_shape,
+                            const int32* bias_data, int32 accum_multiplier,
+                            int accum_shift, const RuntimeShape& output_shape,
+                            int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
-  const int input_size = FlatSizeSkipDim(input_dims, 3);
-  const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+  const int input_size = FlatSizeSkipDim(input_shape, 0);
+  const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+                                      output_shape, output_dim_count - 1);
   // This special fast path for quantized LSTM cells does not try to support
   // odd sizes that we haven't encountered in any LSTM cell, that would
   // require special code (that would go untested until any LSTM cell
@@ -555,18 +597,21 @@
 
 #ifdef GEMMLOWP_NEON
 inline void GEMVForLstmCellWithSymmetricRange(
-    const uint8* input_data, const Dims<4>& input_dims,
-    const uint8* weights_data, const Dims<4>& weights_dims,
-    const int32* bias_data, const Dims<4>& bias_dims, int32 accum_multiplier,
-    int accum_shift, int16* output_data, const Dims<4>& output_dims) {
+    const RuntimeShape& input_shape, const uint8* input_data,
+    const RuntimeShape& weights_shape, const uint8* weights_data,
+    const RuntimeShape& bias_shape, const int32* bias_data,
+    int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
+    int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
-  const int input_size = FlatSizeSkipDim(input_dims, 3);
-  const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+  const int input_size = FlatSizeSkipDim(input_shape, 0);
+  const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+                                      output_shape, output_dim_count - 1);
   // This special fast path for quantized LSTM cells does not try to support
   // odd sizes that we haven't encountered in any LSTM cell, that would
   // require special code (that would go untested until any LSTM cell
@@ -842,14 +887,16 @@
 }
 #endif
 
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
-                           const float* weights_data,
-                           const Dims<4>& weights_dims, const float* bias_data,
-                           const Dims<4>& bias_dims,
-                           float output_activation_min,
-                           float output_activation_max, float* output_data,
-                           const Dims<4>& output_dims) {
+inline void FullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& weights_shape,
+    const float* weights_data, const RuntimeShape& bias_shape,
+    const float* bias_data, const RuntimeShape& output_shape,
+    float* output_data) {
   gemmlowp::ScopedProfilingLabel label("FullyConnected");
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+
   // TODO(b/62193649): this convoluted shape computation (determining
   // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
   // is because the current --variable_batch hack consists in overwriting the
@@ -858,18 +905,38 @@
   // When that is fixed, this should become:
   // const auto input_matrix_map =
   //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
-  const int input_rows = ArraySize(weights_dims, 0);
+  const int dims_count = weights_shape.DimensionsCount();
+  const int input_rows = weights_shape.Dims(dims_count - 1);
   const auto input_matrix_map =
-      MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+      MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
   const auto filter_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+      MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
   auto output_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
 
   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
-  AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
-                                   output_dims, output_activation_min,
-                                   output_activation_max);
+  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+                                   bias_shape, bias_data, output_shape,
+                                   output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+                           const float* weights_data,
+                           const Dims<4>& weights_dims, const float* bias_data,
+                           const Dims<4>& bias_dims,
+                           float output_activation_min,
+                           float output_activation_max, float* output_data,
+                           const Dims<4>& output_dims) {
+  tflite::FullyConnectedParams op_params;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(weights_dims), weights_data,
+                 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+                 output_data);
 }
 
 // legacy, for compatibility with old checked-in code
@@ -887,20 +954,23 @@
 
 #ifdef USE_NEON
 inline void FullyConnectedAsGEMV(
-    const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
-    const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
-    const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+    const RuntimeShape& input_shape, const uint8* input_data,
+    int32 input_offset, const RuntimeShape& filter_shape,
+    const uint8* filter_data, int32 filter_offset,
+    const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
     int32 output_multiplier, int output_shift, int32 output_activation_min,
-    int32 output_activation_max, uint8* output_data,
-    const Dims<4>& output_dims) {
+    int32 output_activation_max, const RuntimeShape& output_shape,
+    uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
-  const int input_size = FlatSizeSkipDim(input_dims, 3);
-  const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+  const int input_size = FlatSizeSkipDim(input_shape, 0);
+  const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+                                      output_shape, output_dim_count - 1);
   static constexpr int kPeel = 4;
   const bool shift_left = (output_shift <= 0);
   for (int k = 0; k < input_size; k += 64) {
@@ -1071,42 +1141,47 @@
   }
 };
 
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
-                           int32 input_offset, const uint8* filter_data,
-                           const Dims<4>& filter_dims, int32 filter_offset,
-                           const int32* bias_data, const Dims<4>& bias_dims,
-                           int32 output_offset, int32 output_multiplier,
-                           int output_shift, int32 output_activation_min,
-                           int32 output_activation_max, uint8* output_data,
-                           const Dims<4>& output_dims,
-                           gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    uint8* output_data, gemmlowp::GemmContext* gemm_context) {
   gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = FlatSizeSkipDim(output_dims, 0);
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
 #ifdef USE_NEON
-  const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+  const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+                                      output_shape, output_dim_count - 1);
   if (batches == 1 && !(output_size % 4)) {
     return FullyConnectedAsGEMV(
-        input_data, input_dims, input_offset, filter_data, filter_dims,
-        filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
-        output_shift, output_activation_min, output_activation_max, output_data,
-        output_dims);
+        input_shape, input_data, input_offset, filter_shape, filter_data,
+        filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
+        output_shift, output_activation_min, output_activation_max,
+        output_shape, output_data);
   }
 #endif  // USE_NEON
-  const int filter_rows = filter_dims.sizes[1];
-  const int filter_cols = filter_dims.sizes[0];
-  TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
-  TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
-  const int output_rows = output_dims.sizes[0];
+  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
+  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
+  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
+  const int output_rows = output_shape.Dims(output_dim_count - 1);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
 
   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
       filter_data, output_rows, filter_cols, filter_cols);
@@ -1123,30 +1198,65 @@
       input_offset, output_pipeline);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+                           int32 input_offset, const uint8* filter_data,
+                           const Dims<4>& filter_dims, int32 filter_offset,
+                           const int32* bias_data, const Dims<4>& bias_dims,
+                           int32 output_offset, int32 output_multiplier,
+                           int output_shift, int32 output_activation_min,
+                           int32 output_activation_max, uint8* output_data,
+                           const Dims<4>& output_dims,
+                           gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  op_params.output_shift = output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                 bias_data, DimsToShape(output_dims), output_data,
+                 gemm_context);
+}
+
 inline void FullyConnected(
-    const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
-    const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
-    const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
-    int32 output_multiplier, int output_shift, int32 output_activation_min,
-    int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
-    gemmlowp::GemmContext* gemm_context) {
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& filter_shape,
+    const uint8* filter_data, const RuntimeShape& bias_shape,
+    const int32* bias_data_int32, const RuntimeShape& output_shape,
+    int16* output_data, gemmlowp::GemmContext* gemm_context) {
   gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
   // This is a copy of the reference implementation. We do not currently have a
   // properly optimized version.
   (void)gemm_context;  // only used in properly optimized code.
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   TFLITE_DCHECK_EQ(output_offset, 0);
+  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
 
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = FlatSizeSkipDim(output_dims, 0);
-  const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(filter_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int filter_dim_count = filter_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+                                       output_shape, output_dim_count - 1);
+  const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
 
   // Implementation of the fully connected node suited to the inside of an LSTM
   // cell. The operands are 8-bit integers, the accumulators are internally
@@ -1157,17 +1267,17 @@
   if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
       output_activation_max == 32767) {
     if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
-      GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
-                                        filter_dims, bias_data_int32, bias_dims,
-                                        output_multiplier, -output_shift,
-                                        output_data, output_dims);
+      GEMVForLstmCellWithSymmetricRange(
+          input_shape, input_data, filter_shape, filter_data, bias_shape,
+          bias_data_int32, output_multiplier, -output_shift, output_shape,
+          output_data);
       return;
     }
     if (!(output_depth % 4) && !(accum_depth % 8)) {
-      GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
-                      filter_offset, bias_data_int32, bias_dims,
-                      output_multiplier, -output_shift, output_data,
-                      output_dims);
+      GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
+                      filter_offset, bias_shape, bias_data_int32,
+                      output_multiplier, -output_shift, output_shape,
+                      output_data);
       return;
     }
   }
@@ -1201,6 +1311,31 @@
       input_offset, output_pipeline);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(
+    const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+    const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+    const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+    int32 output_multiplier, int output_shift, int32 output_activation_min,
+    int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+    gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  op_params.output_shift = output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  FullyConnected(op_params, DimsToShape(input_dims), input_data,
+                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+                 bias_data_int32, DimsToShape(output_dims), output_data,
+                 gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@@ -1543,26 +1678,34 @@
 };
 
 inline void ShuffledFullyConnected(
-    const uint8* input_data, const Dims<4>& input_dims,
-    const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
-    const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
-    int output_shift, int32 output_activation_min, int32 output_activation_max,
-    int16* output_data, const Dims<4>& output_dims,
-    uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+    const FullyConnectedParams& params, const RuntimeShape& input_shape,
+    const uint8* input_data, const RuntimeShape& weights_shape,
+    const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+    const int32* bias_data, const RuntimeShape& output_shape,
+    int16* output_data, uint8* shuffled_input_workspace_data,
+    gemmlowp::GemmContext* gemm_context) {
   gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
   (void)gemm_context;  // only used in optimized code.
   TFLITE_DCHECK_EQ(output_activation_min, -32768);
   TFLITE_DCHECK_EQ(output_activation_max, 32767);
+  TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+  TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
   // TODO(benoitjacob): This really should be:
   //     const int batches = ArraySize(output_dims, 1);
   // but the current --variable_batch hack consists in overwriting the 3rd
   // dimension with the runtime batch size, as we don't keep track for each
   // array of which dimension is the batch dimension in it.
-  const int batches = FlatSizeSkipDim(output_dims, 0);
-  const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
-  const int accum_depth = ArraySize(weights_dims, 0);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+  const int output_dim_count = output_shape.DimensionsCount();
+  const int weights_dim_count = weights_shape.DimensionsCount();
+  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+  const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+                                       output_shape, output_dim_count - 1);
+  const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
   TFLITE_DCHECK((accum_depth % 16) == 0);
   TFLITE_DCHECK((output_depth % 4) == 0);
   // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
@@ -1659,13 +1802,39 @@
   gemm_context->workers_pool()->Execute(tasks);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+    const uint8* input_data, const Dims<4>& input_dims,
+    const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+    const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+    int output_shift, int32 output_activation_min, int32 output_activation_max,
+    int16* output_data, const Dims<4>& output_dims,
+    uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+  tflite::FullyConnectedParams op_params;
+  op_params.output_multiplier = output_multiplier;
+  op_params.output_shift = output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+                         DimsToShape(weights_dims), shuffled_weights_data,
+                         DimsToShape(bias_dims), bias_data,
+                         DimsToShape(output_dims), output_data,
+                         shuffled_input_workspace_data, gemm_context);
+}
+
 template <typename T>
-inline void ExtractPatchIntoBufferColumn(
-    const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
-    int stride_width, int stride_height, int pad_width, int pad_height,
-    int in_width, int in_height, int in_depth, int single_buffer_length,
-    int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
+                                         int h, int b, int kheight, int kwidth,
+                                         int stride_width, int stride_height,
+                                         int pad_width, int pad_height,
+                                         int in_width, int in_height,
+                                         int in_depth, int single_buffer_length,
+                                         int buffer_id, const T* in_data,
+                                         T* conv_buffer_data, uint8 zero_byte) {
   gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
   // This chunk of code reshapes all the inputs corresponding to
   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
   const int kwidth_times_indepth = kwidth * in_depth;
@@ -1687,7 +1856,7 @@
   const int output_row_offset = (buffer_id * single_buffer_length);
   int out_offset =
       output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
-  int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+  int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
 
   // Express all of the calculations as padding around the input patch.
   const int top_padding = h_offset;
@@ -1701,7 +1870,7 @@
   // patch that are off the edge of the input image.
   if (top_padding > 0) {
     const int top_row_elements = (top_padding * kwidth * in_depth);
-    memset(conv_buffer_data + output_row_offset, byte_zero,
+    memset(conv_buffer_data + output_row_offset, zero_byte,
            (top_row_elements * sizeof(T)));
   }
 
@@ -1718,14 +1887,14 @@
     for (int ih = ih_start; ih < ih_end; ++ih) {
       if (left_padding > 0) {
         const int left_start = (out_offset - (left_padding * in_depth));
-        memset(conv_buffer_data + left_start, byte_zero,
+        memset(conv_buffer_data + left_start, zero_byte,
                (left_padding * in_depth * sizeof(T)));
       }
       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
              single_row_num * sizeof(T));
       if (right_padding > 0) {
         const int right_start = (out_offset + single_row_num);
-        memset(conv_buffer_data + right_start, byte_zero,
+        memset(conv_buffer_data + right_start, zero_byte,
                (right_padding * in_depth * sizeof(T)));
       }
       out_offset += kwidth_times_indepth;
@@ -1740,61 +1909,64 @@
     const int bottom_start =
         output_row_offset +
         ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
-    memset(conv_buffer_data + bottom_start, byte_zero,
+    memset(conv_buffer_data + bottom_start, zero_byte,
            (bottom_row_elements * sizeof(T)));
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
-                   const Dims<4>& filter_dims, int stride_width,
-                   int stride_height, int dilation_width_factor,
-                   int dilation_height_factor, int pad_width, int pad_height,
-                   const Dims<4>& output_dims, uint8 byte_zero,
-                   T* im2col_data) {
+inline void ExtractPatchIntoBufferColumn(
+    const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+    int stride_width, int stride_height, int pad_width, int pad_height,
+    int in_width, int in_height, int in_depth, int single_buffer_length,
+    int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+  ExtractPatchIntoBufferColumn(
+      DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+      stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+      single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
+                   const RuntimeShape& input_shape, const T* input_data,
+                   const RuntimeShape& filter_shape,
+                   const RuntimeShape& output_shape, T* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
   // For dilated convolution, the input pixels are not contiguous therefore we
   // can't use the same opitimizations as Im2Col(). Though note this code would
   // work fine for the non-dilated case too (though likely a bit slower).
   gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
   TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   TFLITE_DCHECK(im2col_data);
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  MatchingArraySize(output_dims, 0, filter_dims, 3);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  MatchingDim(output_shape, 3, filter_shape, 0);
 
   // Construct the MxN sized im2col matrix.
   // The rows M, are sub-ordered B x H x W
-  Dims<4> row_dims;
-  row_dims.sizes[0] = output_width;
-  row_dims.sizes[1] = output_height;
-  row_dims.sizes[2] = batches;
-  row_dims.sizes[3] = 1;
-  ComputeStrides(&row_dims);
-
+  const RuntimeShape row_shape({1, batches, output_height, output_width});
   // The columns, N, are sub-ordered Kh x Kw x Din
-  Dims<4> col_dims;
-  col_dims.sizes[0] = input_depth;
-  col_dims.sizes[1] = filter_width;
-  col_dims.sizes[2] = filter_height;
-  col_dims.sizes[3] = 1;
-  ComputeStrides(&col_dims);
-
+  const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
   // Use dimensions M and N to construct dims for indexing directly into im2col
-  Dims<4> im2col_dims;
-  im2col_dims.sizes[0] = FlatSize(col_dims);
-  im2col_dims.sizes[1] = FlatSize(row_dims);
-  im2col_dims.sizes[2] = 1;
-  im2col_dims.sizes[3] = 1;
-  ComputeStrides(&im2col_dims);
+  const RuntimeShape im2col_shape(
+      {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
 
   // Loop through the output rows (B x H x W)
   for (int batch = 0; batch < batches; ++batch) {
@@ -1802,7 +1974,7 @@
       for (int out_x = 0; out_x < output_width; ++out_x) {
         // Each im2col row is an output pixel. Arrange the input data in this
         // row in an order we can conveniently multiply with the filter data.
-        int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
+        int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
         const int in_x_origin = (out_x * stride_width) - pad_width;
         const int in_y_origin = (out_y * stride_height) - pad_height;
         // Loop through all the pixels of the filter (Kh x Kw)
@@ -1813,25 +1985,25 @@
             // Loop through all the filter pixels in this row.
             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
               const int in_x = in_x_origin + dilation_width_factor * filter_x;
-              int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0);
+              int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
               T* dst = im2col_data +
-                       Offset(im2col_dims, col_offset, row_offset, 0, 0);
+                       Offset(im2col_shape, 0, 0, row_offset, col_offset);
               if ((in_x >= 0) && (in_x < input_width)) {
                 // Filter pixel is within the input, copy the input data.
                 T const* src =
-                    input_data + Offset(input_dims, 0, in_x, in_y, batch);
+                    input_data + Offset(input_shape, batch, in_y, in_x, 0);
                 memcpy(dst, src, input_depth * sizeof(T));
               } else {
                 // Filter pixel is outside the input, zero it out.
-                memset(dst, byte_zero, input_depth * sizeof(T));
+                memset(dst, zero_byte, input_depth * sizeof(T));
               }
             }
           } else {
             // Filter row is outside the input, zero out the entire filter row.
-            int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
-            T* dst =
-                im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
-            memset(dst, byte_zero, filter_width * input_depth * sizeof(T));
+            int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
+            T* dst = im2col_data +
+                     Offset(im2col_shape, 0, 0, row_offset, col_offset);
+            memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
           }
         }
       }
@@ -1839,21 +2011,49 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
-            int stride_height, int pad_width, int pad_height, int kheight,
-            int kwidth, uint8 byte_zero, T* output_data,
-            const Dims<4>& output_dims) {
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+                   const Dims<4>& filter_dims, int stride_width,
+                   int stride_height, int dilation_width_factor,
+                   int dilation_height_factor, int pad_width, int pad_height,
+                   const Dims<4>& output_dims, uint8 zero_byte,
+                   T* im2col_data) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+
+  DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), DimsToShape(output_dims),
+                im2col_data);
+}
+
+template <typename T>
+void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
+            const RuntimeShape& input_shape, const T* input_data,
+            const RuntimeShape& output_shape, T* output_data) {
   gemmlowp::ScopedProfilingLabel label("Im2col");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_depth = ArraySize(input_dims, 0);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_height = ArraySize(input_dims, 2);
-  const int output_depth = ArraySize(output_dims, 0);
-  const int output_width = ArraySize(output_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int input_width = input_shape.Dims(2);
+  const int input_height = input_shape.Dims(1);
+  const int output_depth = output_shape.Dims(3);
+  const int output_width = output_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
 
   int buffer_id = 0;
   // Loop over the output nodes.
@@ -1861,24 +2061,110 @@
     for (int h = 0; h < output_height; ++h) {
       for (int w = 0; w < output_width; ++w) {
         ExtractPatchIntoBufferColumn(
-            input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+            input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
             pad_width, pad_height, input_width, input_height, input_depth,
-            output_depth, buffer_id, input_data, output_data, byte_zero);
+            output_depth, buffer_id, input_data, output_data, zero_byte);
         ++buffer_id;
       }
     }
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+            int stride_height, int pad_width, int pad_height, int kheight,
+            int kwidth, uint8 zero_byte, T* output_data,
+            const Dims<4>& output_dims) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = 1;
+  op_params.dilation_height_factor = 1;
+
+  Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+         input_data, DimsToShape(output_dims), output_data);
+}
+
 // legacy, for compatibility with old checked-in code
 template <typename T>
 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int kheight, int kwidth,
-            uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+            uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
-         kwidth, byte_zero, output_data, output_dims);
+         kwidth, zero_byte, output_data, output_dims);
 }
 
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& filter_shape,
+                 const float* filter_data, const RuntimeShape& bias_shape,
+                 const float* bias_data, const RuntimeShape& output_shape,
+                 float* output_data, const RuntimeShape& im2col_shape,
+                 float* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+  (void)im2col_data;
+  (void)im2col_shape;
+  gemmlowp::ScopedProfilingLabel label("Conv");
+
+  // NB: static_cast<float>(0x00000000) == 0.0f
+  const uint8 float_zero_byte = 0x00;
+  const float* gemm_input_data = nullptr;
+  const RuntimeShape* gemm_input_shape = nullptr;
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const bool need_dilated_im2col =
+      dilation_width_factor != 1 || dilation_height_factor != 1;
+  const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+                           filter_width != 1 || filter_height != 1;
+  if (need_dilated_im2col) {
+    DilatedIm2col(params, float_zero_byte, input_shape, input_data,
+                  filter_shape, output_shape, im2col_data);
+    gemm_input_data = im2col_data;
+    gemm_input_shape = &im2col_shape;
+  } else if (need_im2col) {
+    TFLITE_DCHECK(im2col_data);
+    Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
+           input_data, im2col_shape, im2col_data);
+    gemm_input_data = im2col_data;
+    gemm_input_shape = &im2col_shape;
+  } else {
+    // TODO(aselle): We need to make sure to not send im2col if it is not
+    // needed.
+    TFLITE_DCHECK(!im2col_data);
+    gemm_input_data = input_data;
+    gemm_input_shape = &input_shape;
+  }
+
+  const auto im2col_matrix_map =
+      MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
+  const auto filter_matrix_map =
+      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+  auto output_matrix_map =
+      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+
+  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+                                   bias_shape, bias_data, output_shape,
+                                   output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void Conv(const float* input_data, const Dims<4>& input_dims,
                  const float* filter_data, const Dims<4>& filter_dims,
                  const float* bias_data, const Dims<4>& bias_dims,
@@ -1887,67 +2173,43 @@
                  float output_activation_min, float output_activation_max,
                  float* output_data, const Dims<4>& output_dims,
                  float* im2col_data, const Dims<4>& im2col_dims) {
-  (void)im2col_data;
-  (void)im2col_dims;
-  gemmlowp::ScopedProfilingLabel label("Conv");
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
 
-  // NB: static_cast<float>(0x00000000h) == 0.0f
-  const uint8 float_zero_byte = 0x00;
-  const float* gemm_input_data = nullptr;
-  const Dims<4>* gemm_input_dims = nullptr;
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const bool need_dilated_im2col =
-      dilation_width_factor != 1 || dilation_height_factor != 1;
-  const bool need_im2col = stride_width != 1 || stride_height != 1 ||
-                           filter_width != 1 || filter_height != 1;
-  if (need_dilated_im2col) {
-    DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
-                  stride_height, dilation_width_factor, dilation_height_factor,
-                  pad_width, pad_height, output_dims, float_zero_byte,
-                  im2col_data);
-    gemm_input_data = im2col_data;
-    gemm_input_dims = &im2col_dims;
-  } else if (need_im2col) {
-    TFLITE_DCHECK(im2col_data);
-    Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
-           pad_height, filter_height, filter_width, float_zero_byte,
-           im2col_data, im2col_dims);
-    gemm_input_data = im2col_data;
-    gemm_input_dims = &im2col_dims;
-  } else {
-    // TODO(aselle): We need to make sure to not send im2col if it is not
-    // needed.
-    TFLITE_DCHECK(!im2col_data);
-    gemm_input_data = input_data;
-    gemm_input_dims = &input_dims;
-  }
-
-  const auto im2col_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
-  const auto filter_matrix_map =
-      MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
-  auto output_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-
-  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
-
-  AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
-                                   output_dims, output_activation_min,
-                                   output_activation_max);
+  Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+       filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+       output_data, DimsToShape(im2col_dims), im2col_data);
 }
 
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
-                       const int8_t* filter_data, const Dims<4>& filter_dims,
-                       const float* bias_data, const Dims<4>& bias_dims,
-                       int stride_width, int stride_height, int pad_width,
-                       int pad_height, float* scaling_factors_ptr,
-                       float output_activation_min, float output_activation_max,
-                       float* output_data, const Dims<4>& output_dims,
-                       int8_t* im2col_data, const Dims<4>& im2col_dims) {
-  const int batch_size = input_dims.sizes[3];
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
+inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
+                       const RuntimeShape& input_shape,
+                       const int8_t* input_data,
+                       const RuntimeShape& filter_shape,
+                       const int8_t* filter_data,
+                       const RuntimeShape& bias_shape, const float* bias_data,
+                       const RuntimeShape& output_shape, float* output_data,
+                       const RuntimeShape& im2col_shape, int8_t* im2col_data) {
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const float output_activation_min = params.float_activation_min;
+  const float output_activation_max = params.float_activation_max;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
+
+  const int batch_size = input_shape.Dims(0);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
 
   const int8_t* gemm_input_data = nullptr;
   int num_input;
@@ -1958,25 +2220,22 @@
     TFLITE_DCHECK(im2col_data);
     // symmetric quantization assumes zero point of 0.
     const int input_zero_point = 0;
-    Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
-           pad_height, filter_height, filter_width, input_zero_point,
-           im2col_data, im2col_dims);
+
+    Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+           input_data, im2col_shape, im2col_data);
     gemm_input_data = im2col_data;
-    num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
-                im2col_dims.sizes[2] * im2col_dims.sizes[3];
+    num_input = im2col_shape.FlatSize();
   } else {
     TFLITE_DCHECK(!im2col_data);
     gemm_input_data = input_data;
-    num_input = input_dims.sizes[0] * input_dims.sizes[1] *
-                input_dims.sizes[2] * input_dims.sizes[3];
+    num_input = input_shape.FlatSize();
   }
 
   // Flatten 4D matrices into 2D matrices for matrix multiplication.
 
   // Flatten so that each filter has its own row.
-  const int filter_rows = filter_dims.sizes[3];
-  const int filter_cols =
-      filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+  const int filter_rows = filter_shape.Dims(0);
+  const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
 
   // In MatrixBatchVectorMultiplyAccumulate, each output value is the
   // dot product of one row of the first matrix with one row of the second
@@ -1986,15 +2245,14 @@
   const int gemm_input_cols = filter_cols;
   const int gemm_input_rows = num_input / gemm_input_cols;
 
-  const int output_cols = output_dims.sizes[0];
-  const int output_rows =
-      output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+  const int output_cols = output_shape.Dims(3);
+  const int output_rows = FlatSizeSkipDim(output_shape, 3);
   TFLITE_DCHECK_EQ(output_cols, filter_rows);
   TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+  TFLITE_DCHECK_EQ(bias_shape.Dims(3), output_cols);
+  TFLITE_DCHECK_EQ(bias_shape.Dims(2), 1);
+  TFLITE_DCHECK_EQ(bias_shape.Dims(1), 1);
+  TFLITE_DCHECK_EQ(bias_shape.Dims(0), 1);
 
   // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
   // input matrix has its own scale factor. This code duplicates the scale
@@ -2011,11 +2269,39 @@
       scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
       /*result_stride=*/1);
 
-  AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
-                                   output_dims, output_activation_min,
-                                   output_activation_max);
+  AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+                                   bias_shape, bias_data, output_shape,
+                                   output_data);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+                       const int8_t* filter_data, const Dims<4>& filter_dims,
+                       const float* bias_data, const Dims<4>& bias_dims,
+                       int stride_width, int stride_height, int pad_width,
+                       int pad_height, float* scaling_factors_ptr,
+                       float output_activation_min, float output_activation_max,
+                       float* output_data, const Dims<4>& output_dims,
+                       int8_t* im2col_data, const Dims<4>& im2col_dims) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.float_activation_min = output_activation_min;
+  op_params.float_activation_max = output_activation_max;
+
+  HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+             input_data, DimsToShape(filter_dims), filter_data,
+             DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+             output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
           const float* filter_data, const Dims<4>& filter_dims,
@@ -2033,6 +2319,7 @@
        im2col_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2049,6 +2336,7 @@
        im2col_data, im2col_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2062,27 +2350,33 @@
            output_dims, im2col_data, im2col_dims);
 }
 
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
-                 int32 input_offset, const uint8* filter_data,
-                 const Dims<4>& filter_dims, int32 filter_offset,
-                 const int32* bias_data, const Dims<4>& bias_dims,
-                 int stride_width, int stride_height, int dilation_width_factor,
-                 int dilation_height_factor, int pad_width, int pad_height,
-                 int32 output_offset, int32 output_multiplier, int output_shift,
-                 int32 output_activation_min, int32 output_activation_max,
-                 uint8* output_data, const Dims<4>& output_dims,
-                 uint8* im2col_data, const Dims<4>& im2col_dims,
-                 gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& filter_shape,
+                 const uint8* filter_data, const RuntimeShape& bias_shape,
+                 const int32* bias_data, const RuntimeShape& output_shape,
+                 uint8* output_data, const RuntimeShape& im2col_shape,
+                 uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
-
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 4);
 
   const uint8* gemm_input_data = nullptr;
-  const Dims<4>* gemm_input_dims = nullptr;
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int filter_height = ArraySize(filter_dims, 2);
+  const RuntimeShape* gemm_input_shape = nullptr;
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
   const bool need_dilated_im2col =
       dilation_width_factor != 1 || dilation_height_factor != 1;
   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
@@ -2092,53 +2386,47 @@
     const int input_zero_point = -input_offset;
     TFLITE_DCHECK_GE(input_zero_point, 0);
     TFLITE_DCHECK_LE(input_zero_point, 255);
-    DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
-                  stride_height, dilation_width_factor, dilation_height_factor,
-                  pad_width, pad_height, output_dims, input_zero_point,
-                  im2col_data);
+    DilatedIm2col(params, input_zero_point, input_shape, input_data,
+                  filter_shape, output_shape, im2col_data);
     gemm_input_data = im2col_data;
-    gemm_input_dims = &im2col_dims;
+    gemm_input_shape = &im2col_shape;
   } else if (need_im2col) {
     TFLITE_DCHECK(im2col_data);
     const int input_zero_point = -input_offset;
     TFLITE_DCHECK_GE(input_zero_point, 0);
     TFLITE_DCHECK_LE(input_zero_point, 255);
-    Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
-           pad_height, filter_height, filter_width, input_zero_point,
-           im2col_data, im2col_dims);
+    Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+           input_data, im2col_shape, im2col_data);
     gemm_input_data = im2col_data;
-    gemm_input_dims = &im2col_dims;
+    gemm_input_shape = &im2col_shape;
   } else {
     TFLITE_DCHECK(!im2col_data);
     gemm_input_data = input_data;
-    gemm_input_dims = &input_dims;
+    gemm_input_shape = &input_shape;
   }
 
-  const int gemm_input_rows = gemm_input_dims->sizes[0];
+  const int gemm_input_rows = gemm_input_shape->Dims(3);
   // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
   // The root cause has not yet been identified though. Same applies below for
   // the other calls commented out. This is a partial rollback of cl/196819423.
-  // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
-  const int gemm_input_cols = gemm_input_dims->sizes[1] *
-                              gemm_input_dims->sizes[2] *
-                              gemm_input_dims->sizes[3];
-  const int filter_rows = filter_dims.sizes[3];
+  // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
+  const int gemm_input_cols = gemm_input_shape->Dims(0) *
+                              gemm_input_shape->Dims(1) *
+                              gemm_input_shape->Dims(2);
+  const int filter_rows = filter_shape.Dims(0);
   // See b/79927784.
-  // const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+  // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
   const int filter_cols =
-      filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
-  const int output_rows = output_dims.sizes[0];
+      filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
+  const int output_rows = output_shape.Dims(3);
   // See b/79927784.
-  // const int output_cols = FlatSizeSkipDim(output_dims, 0);
+  // const int output_cols = FlatSizeSkipDim(output_shape, 3);
   const int output_cols =
-      output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+      output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
   TFLITE_DCHECK_EQ(output_rows, filter_rows);
   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
-  TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+  TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
       filter_data, filter_rows, filter_cols);
   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
@@ -2154,6 +2442,43 @@
       input_offset, output_pipeline);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+                 int32 input_offset, const uint8* filter_data,
+                 const Dims<4>& filter_dims, int32 filter_offset,
+                 const int32* bias_data, const Dims<4>& bias_dims,
+                 int stride_width, int stride_height, int dilation_width_factor,
+                 int dilation_height_factor, int pad_width, int pad_height,
+                 int32 output_offset, int32 output_multiplier, int output_shift,
+                 int32 output_activation_min, int32 output_activation_max,
+                 uint8* output_data, const Dims<4>& output_dims,
+                 uint8* im2col_data, const Dims<4>& im2col_dims,
+                 gemmlowp::GemmContext* gemm_context) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+  op_params.dilation_width_factor = dilation_width_factor;
+  op_params.dilation_height_factor = dilation_height_factor;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = output_multiplier;
+  op_params.output_shift = output_shift;
+  op_params.quantized_activation_min = output_activation_min;
+  op_params.quantized_activation_max = output_activation_max;
+
+  Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+       filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+       output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                  int32 input_offset, const uint8* filter_data,
                  const Dims<4>& filter_dims, int32 filter_offset,
@@ -2172,6 +2497,7 @@
        im2col_data, im2col_dims, gemm_context);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2201,6 +2527,7 @@
        im2col_data, im2col_dims, gemm_context);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2224,13 +2551,14 @@
        im2col_data, im2col_dims, gemm_context);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac, typename T>
 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
             int pad_width, int pad_height, int kheight, int kwidth,
-            uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+            uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
-         kwidth, byte_zero, output_data, output_dims);
+         kwidth, zero_byte, output_data, output_dims);
 }
 
 // legacy, for compatibility with old checked-in code
@@ -2254,6 +2582,7 @@
                                        output_dims);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // legacy, for compatibility with old checked-in code
 template <FusedActivationFunctionType Ac>
 void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
@@ -2308,9 +2637,9 @@
 
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int input_depth = input_shape.Dims(3);
@@ -2349,9 +2678,9 @@
 
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int output_depth = output_shape.Dims(3);
@@ -3179,7 +3508,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
@@ -3446,10 +3775,11 @@
   bool gemm_already_performed = false;
 #ifdef GEMMLOWP_NEON
   if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
-    GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims,
-                    weights_data_uint8, weights_dims, weights_zero_point,
-                    bias_data_int32, bias_dims, accum_multiplier, accum_shift,
-                    activ_temp_data_int16, activ_temp_dims);
+    GEMVForLstmCell(DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+                    DimsToShape(weights_dims), weights_data_uint8,
+                    weights_zero_point, DimsToShape(bias_dims), bias_data_int32,
+                    accum_multiplier, accum_shift, DimsToShape(activ_temp_dims),
+                    activ_temp_data_int16);
     gemm_already_performed = true;
   }
 #endif
@@ -5430,9 +5760,9 @@
   gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5479,9 +5809,9 @@
   gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5540,9 +5870,9 @@
 
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input1_shape =
+  const RuntimeShape input1_shape =
       RuntimeShape::ExtendedShape(4, unextended_input1_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int output_width = output_shape.Dims(2);
@@ -5626,8 +5956,10 @@
                     const P* pad_value_ptr, const RuntimeShape& output_shape,
                     T* output_data) {
   gemmlowp::ScopedProfilingLabel label("Pad");
-  RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
-  RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+  const RuntimeShape ext_input_shape =
+      RuntimeShape::ExtendedShape(4, input_shape);
+  const RuntimeShape ext_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
 
@@ -5759,7 +6091,7 @@
                   const RuntimeShape& input_shape, const T* input_data,
                   const RuntimeShape& output_shape, T* output_data) {
   gemmlowp::ScopedProfilingLabel label("Slice");
-  RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
   // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
   TFLITE_DCHECK_LE(op_params.begin_count, 4);
   TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -5820,58 +6152,45 @@
 }
 
 template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
-                     const Dims<4>& filter_dims, int stride_width,
-                     int stride_height, int pad_width, int pad_height,
-                     const Dims<4>& output_dims, uint8 zero_byte,
-                     T* im2col_data) {
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+                     const RuntimeShape& input_shape, const T* input_data,
+                     const RuntimeShape& filter_shape,
+                     const RuntimeShape& output_shape, T* im2col_data) {
   gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
-  TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
-  TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
   TFLITE_DCHECK(im2col_data);
 
-  const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
-  const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
-  const int filter_height = ArraySize(filter_dims, 2);
-  const int filter_width = ArraySize(filter_dims, 1);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  MatchingArraySize(output_dims, 0, filter_dims, 0);  // output_depth
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  MatchingDim(output_shape, 3, filter_shape, 3);  // output_depth
 
   // Construct the MxN sized im2col matrix.
   // The rows M, are sub-ordered B x H x W
-  Dims<4> row_dims;
-  row_dims.sizes[0] = output_width;
-  row_dims.sizes[1] = output_height;
-  row_dims.sizes[2] = batches;
-  row_dims.sizes[3] = 1;
-  ComputeStrides(&row_dims);
-
+  const RuntimeShape row_shape({1, batches, output_height, output_width});
   // The columns, N, are sub-ordered Kh x Kw x Din
-  Dims<4> col_dims;
-  col_dims.sizes[0] = input_depth;
-  col_dims.sizes[1] = filter_width;
-  col_dims.sizes[2] = filter_height;
-  col_dims.sizes[3] = 1;
-  ComputeStrides(&col_dims);
-
+  const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
   // Use dimensions M and N to construct dims for indexing directly into im2col
-  Dims<4> im2col_dims;
-  im2col_dims.sizes[0] = FlatSize(col_dims);
-  im2col_dims.sizes[1] = FlatSize(row_dims);
-  im2col_dims.sizes[2] = 1;
-  im2col_dims.sizes[3] = 1;
-  ComputeStrides(&im2col_dims);
+  const RuntimeShape im2col_shape(
+      {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
 
   // Build the im2col matrix by looping through all the input pixels,
   // computing their influence on the output, rather than looping through all
   // the output pixels. We therefore must initialize the im2col array to zero.
   // This is potentially inefficient because we subsequently overwrite bytes
   // set here. However, in practice memset is very fast and costs negligible.
-  memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+  memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
 
   // Loop through the output batches
   for (int batch = 0; batch < batches; ++batch) {
@@ -5891,11 +6210,11 @@
               if ((out_x >= 0) && (out_x < output_width)) {
                 // Copy the input elements of this pixel
                 T const* src =
-                    input_data + Offset(input_dims, 0, in_x, in_y, batch);
+                    input_data + Offset(input_shape, batch, in_y, in_x, 0);
+                int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+                int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
                 T* dst = im2col_data +
-                         Offset(im2col_dims,
-                                Offset(col_dims, 0, filter_x, filter_y, 0),
-                                Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+                         Offset(im2col_shape, 0, 0, row_offset, col_offset);
                 memcpy(dst, src, input_depth * sizeof(T));
               }
             }
@@ -5906,29 +6225,69 @@
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+                     const Dims<4>& filter_dims, int stride_width,
+                     int stride_height, int pad_width, int pad_height,
+                     const Dims<4>& output_dims, uint8 zero_byte,
+                     T* im2col_data) {
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
+
+  TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+                  DimsToShape(filter_dims), DimsToShape(output_dims),
+                  im2col_data);
+}
+
+inline void TransposeConv(
+    const ConvParams& params, const RuntimeShape& input_shape,
+    const float* input_data, const RuntimeShape& filter_shape,
+    const float* filter_data, const RuntimeShape& output_shape,
+    float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+  gemmlowp::ScopedProfilingLabel label("TransposeConv");
+
+  // Note we could use transposed weights with forward conv for unstrided
+  // cases. But we are already getting good performance with this code as-is.
+  TFLITE_DCHECK(im2col_data);
+  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+                  output_shape, im2col_data);
+
+  const auto im2col_matrix_map =
+      MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
+  const auto filter_matrix_map =
+      MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
+  auto output_matrix_map =
+      MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+
+  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           int stride_width, int stride_height, int pad_width,
                           int pad_height, float* output_data,
                           const Dims<4>& output_dims, float* im2col_data,
                           const Dims<4>& im2col_dims) {
-  gemmlowp::ScopedProfilingLabel label("TransposeConv");
+  tflite::ConvParams op_params;
+  // Padding type is ignored, but still set.
+  op_params.padding_type = PaddingType::kSame;
+  op_params.padding_values.width = pad_width;
+  op_params.padding_values.height = pad_height;
+  op_params.stride_width = stride_width;
+  op_params.stride_height = stride_height;
 
-  // Note we could use transposed weights with forward conv for unstrided
-  // cases. But we are already getting good performance with this code as-is.
-  TFLITE_DCHECK(im2col_data);
-  TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
-                  stride_height, pad_width, pad_height, output_dims, 0,
-                  im2col_data);
-
-  const auto im2col_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
-  const auto filter_matrix_map =
-      MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
-  auto output_matrix_map =
-      MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-
-  Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+  TransposeConv(op_params, DimsToShape(input_dims), input_data,
+                DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+                output_data, DimsToShape(im2col_dims), im2col_data);
 }
 
 }  // namespace optimized_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 8664ebc..f87760a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,7 @@
 
 // TODO(ghodrat): Remove this header file and the dependency to internal data
 // structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 #if defined(_MSC_VER)
 #define __restrict__ __restrict
@@ -117,6 +117,10 @@
 void NeonClipVector(const float* vector, int v_size, float abs_limit,
                     float* result);
 
+// Add another vector for each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                                  float* batch_vector);
+
 // Batch vector initialization with another vector.
 void PortableVectorBatchVectorAssign(const float* vector, int v_size,
                                      int n_batch, float* batch_vector);
@@ -172,6 +176,10 @@
 void NeonReductionSumVector(const float* input_vector, float* output_vector,
                             int output_size, int reduction_size);
 
+void PortableMeanStddevNormalization(const float* input_vector,
+                                     float* output_vector, int v_size,
+                                     int n_batch, float normalization_epsilon);
+
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index f882f99..544ef16 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -23,6 +23,32 @@
 
 namespace tflite {
 
+namespace {
+// These constants are used to manipulate the binary representation of doubles.
+// Double-precision binary64 floating point format is:
+// Bit |  63  |  62-52   |   51-0   |
+//     | Sign | Exponent | Fraction |
+// To avoid 64-bit integers as much as possible, I break this into high and
+// low 32-bit chunks. High is:
+// Bit |  31  |  30-20   |      19-0     |
+//     | Sign | Exponent | High Fraction |
+// Low is:
+// Bit |     31-0     |
+//     | Low Fraction |
+// We then access the components through logical bit-wise operations to
+// extract the parts needed, with the positions and masks derived from the
+// layout shown above.
+constexpr uint64_t kSignMask = 0x8000000000000000LL;
+constexpr uint64_t kExponentMask = 0x7ff0000000000000LL;
+constexpr int32_t kExponentShift = 52;
+constexpr int32_t kExponentBias = 1023;
+constexpr uint32_t kExponentIsBadNum = 0x7ff;
+constexpr uint64_t kFractionMask = 0x000fffffffc00000LL;
+constexpr uint32_t kFractionShift = 22;
+constexpr uint32_t kFractionRoundingMask = 0x003fffff;
+constexpr uint32_t kFractionRoundingThreshold = 0x00200000;
+}  // namespace
+
 void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
                         int* shift) {
   if (double_multiplier == 0.) {
@@ -30,8 +56,16 @@
     *shift = 0;
     return;
   }
+#ifdef TFLITE_EMULATE_FLOAT
+  // If we're trying to avoid the use of floating-point instructions (for
+  // example on microcontrollers) then use an alternative implementation
+  // that only requires integer and bitwise operations. To enable this, you
+  // need to set the define during the build process for your platform.
+  int64_t q_fixed = IntegerFrExp(double_multiplier, shift);
+#else   // TFLITE_EMULATE_FLOAT
   const double q = std::frexp(double_multiplier, shift);
   auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+#endif  // TFLITE_EMULATE_FLOAT
   TFLITE_CHECK(q_fixed <= (1ll << 31));
   if (q_fixed == (1ll << 31)) {
     q_fixed /= 2;
@@ -60,6 +94,163 @@
   *left_shift = shift;
 }
 
+int64_t IntegerFrExp(double input, int* shift) {
+  // Make sure our assumptions about the double layout hold.
+  TFLITE_CHECK_EQ(8, sizeof(double));
+
+  // We want to access the bits of the input double value directly, which is
+  // tricky to do safely, so use a union to handle the casting.
+  union {
+    double double_value;
+    uint64_t double_as_uint;
+  } cast_union;
+  cast_union.double_value = input;
+  const uint64_t u = cast_union.double_as_uint;
+
+  // If the bitfield is all zeros apart from the sign bit, this is a normalized
+  // zero value, so return standard values for this special case.
+  if ((u & ~kSignMask) == 0) {
+    *shift = 0;
+    return 0;
+  }
+
+  // Deal with NaNs and Infs, which are always indicated with a fixed pattern in
+  // the exponent, and distinguished by whether the fractions are zero or
+  // non-zero.
+  const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift);
+  if (exponent_part == kExponentIsBadNum) {
+    *shift = std::numeric_limits<int>::max();
+    if (u & kFractionMask) {
+      // NaN, so just return zero (with the exponent set to INT_MAX).
+      return 0;
+    } else {
+      // Infinity, so return +/- INT_MAX.
+      if (u & kSignMask) {
+        return std::numeric_limits<int64_t>::min();
+      } else {
+        return std::numeric_limits<int64_t>::max();
+      }
+    }
+  }
+
+  // The shift is fairly easy to extract from the high bits of the double value,
+  // just by masking it out and applying a bias. The std::frexp() implementation
+  // always returns values between 0.5 and 1.0 though, whereas the exponent
+  // assumes 1.0 to 2.0 is the standard range, so I add on one to match that
+  // interface.
+  *shift = (exponent_part - kExponentBias) + 1;
+
+  // There's an implicit high bit in the double format definition, so make sure
+  // we include that at the top, and then reconstruct the rest of the fractional
+  // value from the remaining fragments.
+  int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift);
+
+  // We're cutting off some bits at the bottom, so to exactly match the standard
+  // frexp implementation here we'll apply rounding by adding one to the least
+  // significant bit of the result if the discarded portion is over half of the
+  // maximum.
+  if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) {
+    fraction += 1;
+  }
+  // Negate the fraction if the sign bit was set.
+  if (u & kSignMask) {
+    fraction *= -1;
+  }
+
+  return fraction;
+}
+
+double DoubleFromFractionAndShift(int64_t fraction, int shift) {
+  union {
+    double double_value;
+    uint64_t double_as_uint;
+  } result;
+
+  // Detect NaNs and infinities.
+  if (shift == std::numeric_limits<int>::max()) {
+    if (fraction == 0) {
+      return NAN;
+    } else if (fraction > 0) {
+      return INFINITY;
+    } else {
+      return -INFINITY;
+    }
+  }
+
+  // Return a normalized zero for a zero fraction.
+  if (fraction == 0) {
+    result.double_as_uint = 0;
+    return result.double_value;
+  }
+
+  bool is_negative = (fraction < 0);
+  int64_t encoded_fraction = is_negative ? -fraction : fraction;
+  int64_t encoded_shift = (shift - 1);
+  while (encoded_fraction < 0x40000000) {
+    encoded_fraction *= 2;
+    encoded_shift -= 1;
+  }
+  while (encoded_fraction > 0x80000000) {
+    encoded_fraction /= 2;
+    encoded_shift += 1;
+  }
+  encoded_fraction -= 0x40000000;
+  if (encoded_shift < -1022) {
+    encoded_shift = -1023;
+  } else if (encoded_shift > 1022) {
+    encoded_shift = 1023;
+  }
+  encoded_shift += kExponentBias;
+  uint64_t encoded_sign = is_negative ? kSignMask : 0;
+  result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) |
+                          (encoded_fraction << kFractionShift);
+  return result.double_value;
+}
+
+double IntegerDoubleMultiply(double a, double b) {
+  int a_shift;
+  const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+  int b_shift;
+  const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+  // Detect NaNs and infinities.
+  if (a_shift == std::numeric_limits<int>::max() ||
+      (b_shift == std::numeric_limits<int>::max())) {
+    return NAN;
+  }
+  const int result_shift = a_shift + b_shift + 1;
+  const int64_t result_fraction = (a_fraction * b_fraction) >> 32;
+  return DoubleFromFractionAndShift(result_fraction, result_shift);
+}
+
+int IntegerDoubleCompare(double a, double b) {
+  int a_shift;
+  const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+  int b_shift;
+  const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+
+  // Detect NaNs and infinities.
+  if (a_shift == std::numeric_limits<int>::max() ||
+      (b_shift == std::numeric_limits<int>::max())) {
+    return 1;
+  }
+
+  if ((a_fraction == 0) && (b_fraction < 0)) {
+    return 1;
+  } else if ((a_fraction < 0) && (b_fraction == 0)) {
+    return -1;
+  } else if (a_shift < b_shift) {
+    return -1;
+  } else if (a_shift > b_shift) {
+    return 1;
+  } else if (a_fraction < b_fraction) {
+    return -1;
+  } else if (a_fraction > b_fraction) {
+    return 1;
+  } else {
+    return 0;
+  }
+}
+
 void PreprocessSoftmaxScaling(double beta, double input_scale,
                               int input_integer_bits,
                               int32_t* quantized_multiplier, int* left_shift) {
@@ -72,8 +263,20 @@
   // result is double equivalent of Q0.31 (actually with more precision). Thus
   // this generates a Q(input_integer_bits).(31-input_integer_bits)
   // representation.
+#ifdef TFLITE_EMULATE_FLOAT
+  const double input_beta = IntegerDoubleMultiply(beta, input_scale);
+  int shift;
+  int64_t fraction = IntegerFrExp(input_beta, &shift);
+  shift += (31 - input_integer_bits);
+  double input_beta_real_multiplier =
+      DoubleFromFractionAndShift(fraction, shift);
+  if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) {
+    input_beta_real_multiplier = (1ll << 31) - 1.0;
+  }
+#else   // TFLITE_EMULATE_FLOAT
   const double input_beta_real_multiplier = std::min(
       beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+#endif  // TFLITE_EMULATE_FLOAT
 
   QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
                                    quantized_multiplier, left_shift);
@@ -97,6 +300,12 @@
 }
 
 int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+#ifdef TFLITE_EMULATE_FLOAT
+  int64_t result = (1 << input_integer_bits) - 1;
+  result <<= (31 - input_integer_bits);
+  result >>= input_left_shift;
+  return result;
+#else   // TFLITE_EMULATE_FLOAT
   const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
                                     (1ll << (31 - input_integer_bits)) /
                                     (1ll << input_left_shift);
@@ -104,6 +313,7 @@
   // After scaling the difference, the result would be at the maximum.  Thus we
   // must ensure that our value has lower magnitude.
   return static_cast<int>(std::floor(max_input_rescaled));
+#endif  // TFLITE_EMULATE_FLOAT
 }
 
 void NudgeQuantizationRange(const float min, const float max,
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 9ee4a47..d74a1ba 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -195,6 +195,44 @@
 void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
                         int* shift);
 
+// Splits a double input value into a returned fraction, and a shift value from
+// the exponent, using only bitwise and integer operations to support
+// microcontrollers and other environments without floating-point support.
+//
+// This is designed to be a replacement for how std::frexp() is used within the
+// QuantizeMultiplier() function, and so has a different signature than the
+// standard version, returning a 64-bit integer rather than a double. This
+// result has a maximum value of 1<<31, with the fraction expressed as a
+// proportion of that maximum.
+//
+// std::frexp() returns NaNs and infinities unmodified, but since we're
+// returning integers that can't represent those values, instead we return
+// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64
+// result of 0 for NaNs, std::numeric_limits<int64_t>::max() for +INFINITY, and
+// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will
+// result in return values that end up truncating some bits at the end,
+// reflecting the loss of precision inherent in denormalization.
+int64_t IntegerFrExp(double input, int* shift);
+
+// Converts an integer fraction in the format produced by IntegerFrExp (where
+// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an
+// IEEE binary64 double format result. The implementation uses only integer and
+// bitwise operators, so no floating point hardware support or emulation is
+// needed. This is here so quantized operations can run non-time-critical
+// preparation calculations on microcontrollers and other platforms without
+// float support.
+double DoubleFromFractionAndShift(int64_t fraction, int shift);
+
+// Performs a multiplication of two numbers in double format, using only integer
+// and bitwise instructions. This is aimed at supporting housekeeping functions
+// for quantized operations on microcontrollers without floating-point hardware.
+double IntegerDoubleMultiply(double a, double b);
+
+// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is
+// greater than b. It is implemented using only integer and logical instructions
+// so that it can be easily run on microcontrollers for quantized operations.
+int IntegerDoubleCompare(double a, double b);
+
 // This first creates a multiplier in a double equivalent of
 // Q(input_integer_bits).(31-input_integer_bits) representation, with extra
 // precision in the double's fractional bits.  It then splits the result into
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 00fc3e9..14281f2 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -191,6 +191,139 @@
   EXPECT_EQ(qp.zero_point, 255);
 }
 
+TEST(QuantizationUtilTest, IntegerFrExp) {
+  int shift;
+  int64_t result = IntegerFrExp(0.0, &shift);
+  EXPECT_EQ(0, result);
+  EXPECT_EQ(0, shift);
+
+  result = IntegerFrExp(1.0, &shift);
+  EXPECT_NEAR(0x40000000, result, 1);
+  EXPECT_EQ(1, shift);
+
+  result = IntegerFrExp(0.25, &shift);
+  EXPECT_NEAR(0x40000000, result, 1);
+  EXPECT_EQ(-1, shift);
+
+  result = IntegerFrExp(-1.0, &shift);
+  EXPECT_NEAR(-(1 << 30), result, 1);
+  EXPECT_EQ(1, shift);
+
+  result = IntegerFrExp(123.45, &shift);
+  EXPECT_NEAR(2071147315, result, 1);
+  EXPECT_EQ(7, shift);
+
+  result = IntegerFrExp(NAN, &shift);
+  EXPECT_NEAR(0, result, 1);
+  EXPECT_EQ(0x7fffffff, shift);
+
+  result = IntegerFrExp(INFINITY, &shift);
+  EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1);
+  EXPECT_EQ(0x7fffffff, shift);
+
+  result = IntegerFrExp(-INFINITY, &shift);
+  EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1);
+  EXPECT_EQ(0x7fffffff, shift);
+}
+
+TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
+  int shift;
+  int32_t result = IntegerFrExp(0.0, &shift);
+  EXPECT_EQ(result, 0);
+  EXPECT_EQ(shift, 0);
+
+  int double_shift;
+  double double_result = std::frexp(0.0, &double_shift);
+  EXPECT_EQ(double_result, 0);
+  EXPECT_EQ(double_shift, 0);
+
+  result = IntegerFrExp(1.0, &shift);
+  EXPECT_NEAR(result, 0x40000000, 1);
+  EXPECT_EQ(shift, 1);
+  double_result = std::frexp(1.0, &double_shift);
+  EXPECT_NEAR(double_result, 0.5, 1e-5);
+  EXPECT_EQ(double_shift, 1);
+
+  result = IntegerFrExp(0.25, &shift);
+  EXPECT_NEAR(result, 0x40000000, 1);
+  EXPECT_EQ(shift, -1);
+  double_result = std::frexp(0.25, &double_shift);
+  EXPECT_NEAR(double_result, 0.5, 1e-5);
+  EXPECT_EQ(double_shift, -1);
+
+  result = IntegerFrExp(-1.0, &shift);
+  EXPECT_NEAR(result, -(1 << 30), 1);
+  EXPECT_EQ(shift, 1);
+  double_result = std::frexp(-1.0, &double_shift);
+  EXPECT_NEAR(double_result, -0.5, 1e-5);
+  EXPECT_EQ(double_shift, 1);
+
+  result = IntegerFrExp(123.45, &shift);
+  EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+  EXPECT_EQ(shift, 7);
+  double_result = std::frexp(123.45, &double_shift);
+  EXPECT_NEAR(double_result, 0.964453, 1e-5);
+  EXPECT_EQ(double_shift, 7);
+}
+
+TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
+  double result = DoubleFromFractionAndShift(0, 0);
+  EXPECT_EQ(0, result);
+
+  result = DoubleFromFractionAndShift(0x40000000, 1);
+  EXPECT_NEAR(1.0, result, 1e-5);
+
+  result = DoubleFromFractionAndShift(0x40000000, 2);
+  EXPECT_NEAR(2.0, result, 1e-5);
+
+  int shift;
+  int64_t fraction = IntegerFrExp(3.0, &shift);
+  result = DoubleFromFractionAndShift(fraction, shift);
+  EXPECT_NEAR(3.0, result, 1e-5);
+
+  fraction = IntegerFrExp(123.45, &shift);
+  result = DoubleFromFractionAndShift(fraction, shift);
+  EXPECT_NEAR(123.45, result, 1e-5);
+
+  fraction = IntegerFrExp(-23.232323, &shift);
+  result = DoubleFromFractionAndShift(fraction, shift);
+  EXPECT_NEAR(-23.232323, result, 1e-5);
+
+  fraction = IntegerFrExp(NAN, &shift);
+  result = DoubleFromFractionAndShift(fraction, shift);
+  EXPECT_TRUE(std::isnan(result));
+
+  fraction = IntegerFrExp(INFINITY, &shift);
+  result = DoubleFromFractionAndShift(fraction, shift);
+  EXPECT_FALSE(std::isfinite(result));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
+  EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5);
+  EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5);
+  EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5);
+  EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5);
+  EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5);
+  EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5);
+  EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5);
+  EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5);
+  EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5);
+  EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5);
+  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
+  EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleCompare) {
+  EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
+  EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
+  EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));
+  EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
+  EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));
+  EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));
+  EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
+  EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
+}
+
 #ifdef GTEST_HAS_DEATH_TEST
 TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
   EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5..bb5d590 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -25,8 +25,9 @@
 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
                           const float* filter_data, const Dims<4>& filter_dims,
                           const float* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height, int pad_width,
-                          int pad_height, int depth_multiplier,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
                           float output_activation_min,
                           float output_activation_max, float* output_data,
                           const Dims<4>& output_dims) {
@@ -52,8 +53,9 @@
             float total = 0.f;
             for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
               for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-                const int in_x = in_x_origin + filter_x;
-                const int in_y = in_y_origin + filter_y;
+                const int in_x = in_x_origin + dilation_width_factor * filter_x;
+                const int in_y =
+                    in_y_origin + dilation_height_factor * filter_y;
                 // If the location is outside the bounds of the input image,
                 // use zero as a default value.
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -81,6 +83,20 @@
   }
 }
 
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+                          const float* filter_data, const Dims<4>& filter_dims,
+                          const float* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height, int pad_width,
+                          int pad_height, int depth_multiplier,
+                          float output_activation_min,
+                          float output_activation_max, float* output_data,
+                          const Dims<4>& output_dims) {
+  DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+                bias_dims, stride_width, stride_height, 1, 1, pad_width,
+                pad_height, depth_multiplier, output_activation_min,
+                output_activation_max, output_data, output_dims);
+}
+
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index d577392..5e3e899 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -30,8 +30,9 @@
                           int32 input_offset, const uint8* filter_data,
                           const Dims<4>& filter_dims, int32 filter_offset,
                           const int32* bias_data, const Dims<4>& bias_dims,
-                          int stride_width, int stride_height, int pad_width,
-                          int pad_height, int depth_multiplier,
+                          int stride_width, int stride_height,
+                          int dilation_width_factor, int dilation_height_factor,
+                          int pad_width, int pad_height, int depth_multiplier,
                           int32 output_offset, int32 output_multiplier,
                           int output_shift, int32 output_activation_min,
                           int32 output_activation_max, uint8* output_data,
@@ -58,8 +59,9 @@
             int32 acc = 0;
             for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
               for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
-                const int in_x = in_x_origin + filter_x;
-                const int in_y = in_y_origin + filter_y;
+                const int in_x = in_x_origin + dilation_width_factor * filter_x;
+                const int in_y =
+                    in_y_origin + dilation_height_factor * filter_y;
                 // If the location is outside the bounds of the input image,
                 // use zero as a default value.
                 if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
@@ -90,6 +92,24 @@
   }
 }
 
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+                          int32 input_offset, const uint8* filter_data,
+                          const Dims<4>& filter_dims, int32 filter_offset,
+                          const int32* bias_data, const Dims<4>& bias_dims,
+                          int stride_width, int stride_height, int pad_width,
+                          int pad_height, int depth_multiplier,
+                          int32 output_offset, int32 output_multiplier,
+                          int output_shift, int32 output_activation_min,
+                          int32 output_activation_max, uint8* output_data,
+                          const Dims<4>& output_dims) {
+  DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+                filter_offset, bias_data, bias_dims, stride_width,
+                stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+                output_offset, output_multiplier, output_shift,
+                output_activation_min, output_activation_max, output_data,
+                output_dims);
+}
+
 // Legacy, for compatibility with old checked-in code.
 template <FusedActivationFunctionType Ac>
 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index e79e75a..77e60ad 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -16,7 +16,7 @@
 #include <string.h>
 #include <algorithm>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/round.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -173,6 +173,16 @@
   }
 }
 
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                                  float* batch_vector) {
+  for (int b = 0; b < n_batch; b++) {
+    for (int i = 0; i < v_size; ++i) {
+      batch_vector[i] += vector[i];
+    }
+    batch_vector += v_size;
+  }
+}
+
 void PortableVectorBatchVectorAssign(const float* vector, int v_size,
                                      int n_batch, float* batch_vector) {
   for (int b = 0; b < n_batch; b++) {
@@ -243,5 +253,31 @@
   }
 }
 
+void PortableMeanStddevNormalization(const float* input_vector,
+                                     float* output_vector, int v_size,
+                                     int n_batch, float normalization_epsilon) {
+  for (int batch = 0; batch < n_batch; ++batch) {
+    float sum = 0.0f;
+    float sum_sq = 0.0f;
+    for (int i = 0; i < v_size; ++i) {
+      sum += input_vector[i];
+      sum_sq += input_vector[i] * input_vector[i];
+    }
+    const float mean = sum / v_size;
+    float stddev_inv = 0.0f;
+    const float variance = sum_sq / v_size - mean * mean;
+    if (variance == 0) {
+      stddev_inv = 1.0f / sqrt(normalization_epsilon);
+    } else {
+      stddev_inv = 1.0f / sqrt(variance);
+    }
+    for (int i = 0; i < v_size; ++i) {
+      output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+    }
+    input_vector += v_size;
+    output_vector += v_size;
+  }
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index 3829be0..714b116 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,7 @@
 
 // TODO(ghodrat): Remove this header file and the dependency to internal data
 // structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 #if defined(_MSC_VER)
 #define __restrict__ __restrict
@@ -87,6 +87,10 @@
 void PortableVectorBatchVectorAssign(const float* vector, int v_size,
                                      int n_batch, float* batch_vector);
 
+// Add another vector for each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                                  float* batch_vector);
+
 // Apply sigmoid to elements of a vector.
 void PortableApplySigmoidToVector(const float* vector, int v_size,
                                   float* result);
@@ -125,6 +129,12 @@
 void PortableReductionSumVector(const float* input_vector, float* output_vector,
                                 int output_size, int reduction_size);
 
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid divergence.
+void PortableMeanStddevNormalization(const float* input_vector,
+                                     float* output_vector, int v_size,
+                                     int n_batch, float normalization_epsilon);
+
 float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
 
 bool IsZeroVector(const float* vector, int v_size) {
@@ -193,6 +203,11 @@
                                            result, result_stride);
 }
 
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                          float* batch_vector) {
+  PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
 void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
                              float* batch_vector) {
   PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -240,6 +255,13 @@
                              reduction_size);
 }
 
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+                             int v_size, int n_batch,
+                             float normalization_epsilon) {
+  PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+                                  normalization_epsilon);
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e5b71f8..66f18ec 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -419,9 +419,9 @@
                          T* output_data) {
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int input_depth = input_shape.Dims(3);
@@ -472,9 +472,9 @@
                          T* output_data) {
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int input_depth = input_shape.Dims(3);
@@ -1117,7 +1117,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1158,7 +1158,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1200,7 +1200,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1350,7 +1350,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
@@ -1483,7 +1483,7 @@
   // The input shapes are extended as part of NdArrayDesc initialization.
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
@@ -1579,7 +1579,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
@@ -1713,7 +1713,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1754,7 +1754,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1818,7 +1818,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1858,7 +1858,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1897,7 +1897,7 @@
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                       &desc2);
-  RuntimeShape extended_output_shape =
+  const RuntimeShape extended_output_shape =
       RuntimeShape::ExtendedShape(4, output_shape);
 
   // In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -3398,10 +3398,12 @@
   }
 }
 
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
-                       int32 zero_point, double scale, float* output_data,
-                       const Dims<4>& output_dims) {
-  const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
+  int32 zero_point = op_params.zero_point;
+  double scale = op_params.scale;
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
     int32 val = input_data[i];
@@ -3410,9 +3412,25 @@
   }
 }
 
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
-                      float rmin, float rmax, int num_bits, float* output_data,
-                      const Dims<4>& output_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+                       int32 zero_point, double scale, float* output_data,
+                       const Dims<4>& output_dims) {
+  tflite::DequantizationParams op_params;
+  op_params.zero_point = zero_point;
+  op_params.scale = scale;
+
+  Dequantize(op_params, DimsToShape(input_dims), input_data,
+             DimsToShape(output_dims), output_data);
+}
+
+inline void FakeQuant(const tflite::FakeQuantParams& op_params,
+                      const RuntimeShape& input_shape, const float* input_data,
+                      const RuntimeShape& output_shape, float* output_data) {
+  float rmin = op_params.minmax.min;
+  float rmax = op_params.minmax.max;
+  int num_bits = op_params.num_bits;
   // 0 should always be a representable value. Let's assume that the initial
   // min,max range contains 0.
   TFLITE_DCHECK_LE(rmin, 0.0f);
@@ -3425,11 +3443,25 @@
   float nudged_min, nudged_max, nudged_scale;
   NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
                          &nudged_max, &nudged_scale);
-  const int flat_size = MatchingFlatSize(output_dims, input_dims);
+  const int flat_size = MatchingFlatSize(input_shape, output_shape);
   FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
                     output_data, flat_size);
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+                      float rmin, float rmax, int num_bits, float* output_data,
+                      const Dims<4>& output_dims) {
+  tflite::FakeQuantParams op_params;
+  op_params.num_bits = num_bits;
+  op_params.minmax.min = rmin;
+  op_params.minmax.max = rmax;
+
+  FakeQuant(op_params, DimsToShape(input_dims), input_data,
+            DimsToShape(output_dims), output_data);
+}
+
 template <typename SrcT, typename DstT>
 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
                  const RuntimeShape& output_shape, DstT* output_data) {
@@ -3452,23 +3484,54 @@
 }
 
 template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
-                   int input_rank, const int32* coords_data,
-                   const Dims<4>& coords_dims, T* output_data,
-                   const Dims<4>& output_dims) {
-  TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
-  int stride = input_dims.strides[input_rank - 1];
+inline void Gather(const tflite::GatherParams& op_params,
+                   const RuntimeShape& input_shape, const T* input_data,
+                   const RuntimeShape& coords_shape, const int32* coords_data,
+                   const RuntimeShape& output_shape, T* output_data) {
+  // Enable these checks when moving legacy ops to legacy_reference_ops.
+  //
+  // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+  const int input_rank = op_params.input_rank;
+  const int gather_dimensions = output_shape.DimensionsCount();
+  TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+  const int axis = gather_dimensions - input_rank;
+  TFLITE_DCHECK_LT(axis, gather_dimensions);
+  TFLITE_DCHECK_GE(axis, 0);
+  const int coords_count = coords_shape.FlatSize();
+  TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis));
+
+  int64_t stride = 1;
+  for (int i = axis + 1; i < gather_dimensions; ++i) {
+    stride *= input_shape.Dims(i);
+  }
   T* out = output_data;
 
-  for (int i = 0; i < coords_dims.sizes[0]; i++) {
+  for (int i = 0; i < coords_count; ++i) {
     TFLITE_DCHECK_GE(coords_data[i], 0);
-    TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+    TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis));
     const T* in = input_data + coords_data[i] * stride;
     memcpy(out, in, sizeof(T) * stride);
     out += stride;
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4> version.
+// When moving legacy ops to legacy_reference_ops, replace content with looser
+// implementation.
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+                   int input_rank, const int32* coords_data,
+                   const Dims<4>& coords_dims, T* output_data,
+                   const Dims<4>& output_dims) {
+  tflite::GatherParams op_params;
+  op_params.input_rank = input_rank;
+
+  Gather(op_params, DimsToShape(input_dims), input_data,
+         DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+         output_data);
+}
+
 template <typename T>
 inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
                            const RuntimeShape& unextended_input_shape,
@@ -3480,11 +3543,11 @@
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input_shape =
+  const RuntimeShape input_shape =
       RuntimeShape::ExtendedShape(4, unextended_input_shape);
-  RuntimeShape output_size_shape =
+  const RuntimeShape output_size_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -3543,9 +3606,9 @@
     const RuntimeShape& unextended_output_shape, T* output_data) {
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input1_shape =
+  const RuntimeShape input1_shape =
       RuntimeShape::ExtendedShape(4, unextended_input1_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int depth = input1_shape.Dims(3);
@@ -3600,9 +3663,9 @@
     const RuntimeShape& unextended_output_shape, T* output_data) {
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape input1_shape =
+  const RuntimeShape input1_shape =
       RuntimeShape::ExtendedShape(4, unextended_input1_shape);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   const int output_width = output_shape.Dims(2);
@@ -3656,8 +3719,10 @@
                     const RuntimeShape& input_shape, const T* input_data,
                     const P* pad_value_ptr, const RuntimeShape& output_shape,
                     T* output_data) {
-  RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
-  RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+  const RuntimeShape ext_input_shape =
+      RuntimeShape::ExtendedShape(4, input_shape);
+  const RuntimeShape ext_output_shape =
+      RuntimeShape::ExtendedShape(4, output_shape);
   TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
   TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
 
@@ -3744,63 +3809,115 @@
 }
 
 template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
-                         int begin_mask, int end_mask, int shrink_axis_mask,
-                         const std::vector<int>& start_indices,
-                         const std::vector<int>& stop_indices,
-                         const std::vector<int>& strides, T* output_data,
-                         const Dims<4>& output_dims) {
-  // Note that the axis orders are reversed for runtime ops, so the indices,
-  // strides and masks must be as well too.
-  TFLITE_DCHECK_EQ(start_indices.size(), 4);
-  TFLITE_DCHECK_EQ(stop_indices.size(), 4);
-  TFLITE_DCHECK_EQ(strides.size(), 4);
-  const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
-                                                  strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+                         const RuntimeShape& unextended_input_shape,
+                         const T* input_data,
+                         const RuntimeShape& unextended_output_shape,
+                         T* output_data) {
+  // Note that the output_shape is not used herein.
+  tflite::StridedSliceParams params_copy = op_params;
+
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  // Reverse and pad to 4 dimensions because that is what the runtime code
+  // requires (ie. all shapes must be 4D and are given backwards).
+  strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+  const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
   const int stop_b =
-      strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
-                                 strides, input_dims.sizes, 3, start_b);
-  const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
-                                                  strides, input_dims.sizes, 2);
+      strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+  const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
   const int stop_h =
-      strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
-                                 strides, input_dims.sizes, 2, start_h);
-  const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
-                                                  strides, input_dims.sizes, 1);
+      strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+  const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
   const int stop_w =
-      strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
-                                 strides, input_dims.sizes, 1, start_w);
-  const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
-                                                  strides, input_dims.sizes, 0);
+      strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+  const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
   const int stop_d =
-      strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
-                                 strides, input_dims.sizes, 0, start_d);
+      strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
 
   T* out_ptr = output_data;
   for (int in_b = start_b;
-       !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
-       in_b += strides[3]) {
+       !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+       in_b += params_copy.strides[0]) {
     for (int in_h = start_h;
-         !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
-         in_h += strides[2]) {
+         !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+         in_h += params_copy.strides[1]) {
       for (int in_w = start_w;
-           !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
-           in_w += strides[1]) {
-        for (int in_d = start_d;
-             !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
-             in_d += strides[0]) {
-          *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+           !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+           in_w += params_copy.strides[2]) {
+        for (int in_d = start_d; !strided_slice::LoopCondition(
+                 in_d, stop_d, params_copy.strides[3]);
+             in_d += params_copy.strides[3]) {
+          *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
         }
       }
     }
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+  n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+  n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+  n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+  return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+          ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+  TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+  TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+  std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+  std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+  std::reverse(p->strides, p->strides + p->strides_count);
+
+  p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+                  (32 - p->start_indices_count);
+  p->ellipsis_mask =
+      LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+      (32 - p->start_indices_count);
+  p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+                (32 - p->start_indices_count);
+  p->new_axis_mask =
+      LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+      (32 - p->start_indices_count);
+  p->shrink_axis_mask =
+      LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+      (32 - p->start_indices_count);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+                         int begin_mask, int end_mask, int shrink_axis_mask,
+                         const std::vector<int>& start_indices,
+                         const std::vector<int>& stop_indices,
+                         const std::vector<int>& strides, T* output_data,
+                         const Dims<4>& output_dims) {
+  TFLITE_DCHECK_EQ(start_indices.size(), 4);
+  auto op_params = strided_slice::BuildStridedSliceParams(
+      begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+      strides);
+  StridedSliceReverseIndices(&op_params);
+
+  StridedSlice(op_params, DimsToShape(input_dims), input_data,
+               DimsToShape(output_dims), output_data);
+}
+
 template <typename T>
 inline void Slice(const tflite::SliceParams& op_params,
                   const RuntimeShape& input_shape, const T* input_data,
                   const RuntimeShape& output_shape, T* output_data) {
-  RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+  const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
   // TODO(dkalenichenko): This op only supports 4D tensors or smaller.
   TFLITE_DCHECK_LE(op_params.begin_count, 4);
   TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -4018,22 +4135,32 @@
 }
 
 template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
-                 const std::vector<int>& reduction_indices, T* output_data,
-                 const Dims<4>& output_dims) {
-  const int output_batch = ArraySize(output_dims, 3);
-  const int output_height = ArraySize(output_dims, 2);
-  const int output_width = ArraySize(output_dims, 1);
-  const int output_depth = ArraySize(output_dims, 0);
+inline void Mean(const tflite::MeanParams& op_params,
+                 const RuntimeShape& unextended_input_shape,
+                 const T* input_data,
+                 const RuntimeShape& unextended_output_shape, T* output_data) {
+  gemmlowp::ScopedProfilingLabel label("Mean");
 
-  const int input_height = ArraySize(input_dims, 2);
-  const int input_width = ArraySize(input_dims, 1);
+  TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape input_shape =
+      RuntimeShape::ExtendedShape(4, unextended_input_shape);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  const int output_batch = output_shape.Dims(0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  const int output_depth = output_shape.Dims(3);
+
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
 
   // The current implementation only supports simultaneous reduction over
   // width and height.
-  TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
-  TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
-                (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+  TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+  TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+                (op_params.axis[0] == 2 && op_params.axis[1] == 1));
   TFLITE_DCHECK_EQ(output_height, 1);
   TFLITE_DCHECK_EQ(output_width, 1);
 
@@ -4042,15 +4169,31 @@
       float value = 0;
       for (int in_h = 0; in_h < input_height; ++in_h) {
         for (int in_w = 0; in_w < input_width; ++in_w) {
-          value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+          value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
         }
       }
-      output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+      output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
           value / (input_width * input_height);
     }
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+                 const std::vector<int>& reduction_indices, T* output_data,
+                 const Dims<4>& output_dims) {
+  tflite::MeanParams op_params;
+  op_params.axis_count = reduction_indices.size();
+  for (int i = 0; i < op_params.axis_count; ++i) {
+    op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+  }
+
+  Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+       output_data);
+}
+
 // Computes the mean of elements across dimensions given in axis.
 // It does so in two stages, first calculates the sum of elements along the axis
 // then divides it by the number of element in axis for quantized values.
@@ -4149,7 +4292,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
@@ -4337,9 +4480,10 @@
 using ComparisonFn = bool (*)(T, T);
 
 template <typename T, ComparisonFn<T> F>
-inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
-                       const RuntimeShape& input2_shape, const T* input2_data,
-                       const RuntimeShape& output_shape, bool* output_data) {
+inline void ComparisonImpl(
+    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+    const T* input1_data, const RuntimeShape& input2_shape,
+    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
   const int64_t flatsize =
       MatchingFlatSize(input1_shape, input2_shape, output_shape);
   for (int64_t i = 0; i < flatsize; ++i) {
@@ -4347,16 +4491,63 @@
   }
 }
 
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+                       const RuntimeShape& input1_shape,
+                       const float* input1_data,
+                       const RuntimeShape& input2_shape,
+                       const float* input2_data,
+                       const RuntimeShape& output_shape, bool* output_data) {
+  ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+                           input2_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <typename T, ComparisonFn<T> F>
 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
                        const T* input2_data, const Dims<4>& input2_dims,
                        bool* output_data, const Dims<4>& output_dims) {
-  Comparison<T, F>(DimsToShape(input1_dims), input1_data,
-                   DimsToShape(input2_dims), input2_data,
-                   DimsToShape(output_dims), output_data);
+  ComparisonParams op_params;
+  // No parameters needed.
+  ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+                       DimsToShape(input2_dims), input2_data,
+                       DimsToShape(output_dims), output_data);
 }
 
 template <typename T, ComparisonFn<int32> F>
+inline void ComparisonWithScaling(
+    const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+    const T* input1_data, const RuntimeShape& input2_shape,
+    const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+  int left_shift = op_params.left_shift;
+  int32 input1_offset = op_params.input1_offset;
+  int32 input1_multiplier = op_params.input1_multiplier;
+  int input1_shift = op_params.input1_shift;
+  int32 input2_offset = op_params.input2_offset;
+  int32 input2_multiplier = op_params.input2_multiplier;
+  int input2_shift = op_params.input2_shift;
+
+  const int64_t flatsize =
+      MatchingFlatSize(input1_shape, input2_shape, output_shape);
+  for (int64_t i = 0; i < flatsize; ++i) {
+    const int32 input1_val = input1_offset + input1_data[i];
+    const int32 input2_val = input2_offset + input2_data[i];
+    const int32 shifted_input1_val = input1_val * (1 << left_shift);
+    const int32 shifted_input2_val = input2_val * (1 << left_shift);
+    const int32 scaled_input1_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input1_val, input1_multiplier, input1_shift);
+    const int32 scaled_input2_val =
+        MultiplyByQuantizedMultiplierSmallerThanOneExp(
+            shifted_input2_val, input2_multiplier, input2_shift);
+    output_data[i] = F(scaled_input1_val, scaled_input2_val);
+  }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
 inline void Comparison(int left_shift, const T* input1_data,
                        const Dims<4>& input1_dims, int32 input1_offset,
                        int32 input1_multiplier, int input1_shift,
@@ -4364,47 +4555,131 @@
                        int32 input2_offset, int32 input2_multiplier,
                        int input2_shift, bool* output_data,
                        const Dims<4>& output_dims) {
-  const int64_t flatsize =
-      MatchingFlatSize(input1_dims, input2_dims, output_dims);
-  for (int64_t i = 0; i < flatsize; ++i) {
-    const int32 input1_val = input1_offset + input1_data[i];
-    const int32 input2_val = input2_offset + input2_data[i];
-    const int32 shifted_input1_val = input1_val * (1 << left_shift);
-    const int32 shifted_input2_val = input2_val * (1 << left_shift);
-    const int32 scaled_input1_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input1_val, input1_multiplier,
-            kReverseShift * input1_shift);
-    const int32 scaled_input2_val =
-        MultiplyByQuantizedMultiplierSmallerThanOneExp(
-            shifted_input2_val, input2_multiplier,
-            kReverseShift * input2_shift);
-    output_data[i] = F(scaled_input1_val, scaled_input2_val);
-  }
+  tflite::ComparisonParams op_params;
+  op_params.left_shift = left_shift;
+  op_params.input1_offset = input1_offset;
+  op_params.input1_multiplier = input1_multiplier;
+  op_params.input1_shift = kReverseShift * input1_shift;
+  op_params.input2_offset = input2_offset;
+  op_params.input2_multiplier = input2_multiplier;
+  op_params.input2_shift = kReverseShift * input2_shift;
+
+  ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+                              DimsToShape(input2_dims), input2_data,
+                              DimsToShape(output_dims), output_data);
 }
 
 template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison4DSlowImpl(
+    const ComparisonParams& op_params,
+    const RuntimeShape& unextended_input1_shape, const T* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T* input2_data,
+    const RuntimeShape& unextended_output_shape, bool* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          output_data[Offset(output_shape, b, y, x, c)] =
+              F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+                input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
+        }
+      }
+    }
+  }
+}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+                                      const RuntimeShape& input1_shape,
+                                      const float* input1_data,
+                                      const RuntimeShape& input2_shape,
+                                      const float* input2_data,
+                                      const RuntimeShape& output_shape,
+                                      bool* output_data) {
+  BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+                                          input2_shape, input2_data,
+                                          output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<T> F>
 inline void BroadcastComparison(const T* input1_data,
                                 const Dims<4>& input1_dims,
                                 const T* input2_data,
                                 const Dims<4>& input2_dims, bool* output_data,
                                 const Dims<4>& output_dims) {
+  ComparisonParams op_params;
+  // No parameters needed.
+  BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+                                      input1_data, DimsToShape(input2_dims),
+                                      input2_data, DimsToShape(output_dims),
+                                      output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison4DSlowWithScaling(
+    const ComparisonParams& op_params,
+    const RuntimeShape& unextended_input1_shape, const T* input1_data,
+    const RuntimeShape& unextended_input2_shape, const T* input2_data,
+    const RuntimeShape& unextended_output_shape, bool* output_data) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-  for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
-    for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
-      for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
-        for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
-          output_data[Offset(output_dims, c, x, y, b)] =
-              F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
-                input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
+
+  int left_shift = op_params.left_shift;
+  int32 input1_offset = op_params.input1_offset;
+  int32 input1_multiplier = op_params.input1_multiplier;
+  int input1_shift = op_params.input1_shift;
+  int32 input2_offset = op_params.input2_offset;
+  int32 input2_multiplier = op_params.input2_multiplier;
+  int input2_shift = op_params.input2_shift;
+
+  for (int b = 0; b < output_shape.Dims(0); ++b) {
+    for (int y = 0; y < output_shape.Dims(1); ++y) {
+      for (int x = 0; x < output_shape.Dims(2); ++x) {
+        for (int c = 0; c < output_shape.Dims(3); ++c) {
+          const int32 input1_val =
+              input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
+          const int32 input2_val =
+              input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
+          const int32 shifted_input1_val = input1_val * (1 << left_shift);
+          const int32 shifted_input2_val = input2_val * (1 << left_shift);
+          const int32 scaled_input1_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input1_val, input1_multiplier, input1_shift);
+          const int32 scaled_input2_val =
+              MultiplyByQuantizedMultiplierSmallerThanOneExp(
+                  shifted_input2_val, input2_multiplier, input2_shift);
+          output_data[Offset(output_shape, b, y, x, c)] =
+              F(scaled_input1_val, scaled_input2_val);
         }
       }
     }
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
 template <typename T, ComparisonFn<int32> F>
 inline void BroadcastComparison(int left_shift, const T* input1_data,
                                 const Dims<4>& input1_dims, int32 input1_offset,
@@ -4413,80 +4688,107 @@
                                 const Dims<4>& input2_dims, int32 input2_offset,
                                 int32 input2_multiplier, int input2_shift,
                                 bool* output_data, const Dims<4>& output_dims) {
-  NdArrayDesc<4> desc1;
-  NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-  for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
-    for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
-      for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
-        for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
-          const int32 input1_val =
-              input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
-          const int32 input2_val =
-              input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
-          const int32 shifted_input1_val = input1_val * (1 << left_shift);
-          const int32 shifted_input2_val = input2_val * (1 << left_shift);
-          const int32 scaled_input1_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input1_val, input1_multiplier,
-                  kReverseShift * input1_shift);
-          const int32 scaled_input2_val =
-              MultiplyByQuantizedMultiplierSmallerThanOneExp(
-                  shifted_input2_val, input2_multiplier,
-                  kReverseShift * input2_shift);
-          output_data[Offset(output_dims, c, x, y, b)] =
-              F(scaled_input1_val, scaled_input2_val);
-        }
-      }
-    }
-  }
+  ComparisonParams op_params;
+
+  op_params.left_shift = left_shift;
+  op_params.input1_offset = input1_offset;
+  op_params.input1_multiplier = input1_multiplier;
+  op_params.input1_shift = kReverseShift * input1_shift;
+  op_params.input2_offset = input2_offset;
+  op_params.input2_multiplier = input2_multiplier;
+  op_params.input2_shift = kReverseShift * input2_shift;
+
+  BroadcastComparison4DSlowWithScaling<T, F>(
+      op_params, DimsToShape(input1_dims), input1_data,
+      DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+      output_data);
 }
 
-#define TFLITE_COMPARISON_OP(name)                                            \
-  template <typename T>                                                       \
-  inline void name(const T* input1_data, const Dims<4>& input1_dims,          \
-                   const T* input2_data, const Dims<4>& input2_dims,          \
-                   bool* output_data, const Dims<4>& output_dims) {           \
-    gemmlowp::ScopedProfilingLabel label(#name);                              \
-    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,            \
-                            input2_dims, output_data, output_dims);           \
-  }                                                                           \
-  template <typename T>                                                       \
-  inline void name(                                                           \
-      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
-      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
-      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
-      int32 input2_multiplier, int input2_shift, bool* output_data,           \
-      const Dims<4>& output_dims) {                                           \
-    gemmlowp::ScopedProfilingLabel label(#name "/8bit");                      \
-    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,             \
-                            input1_offset, input1_multiplier, input1_shift,   \
-                            input2_data, input2_dims, input2_offset,          \
-                            input2_multiplier, input2_shift, output_data,     \
-                            output_dims);                                     \
-  }                                                                           \
-  template <typename T>                                                       \
-  inline void Broadcast##name(                                                \
-      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
-      const Dims<4>& input2_dims, bool* output_data,                          \
-      const Dims<4>& output_dims) {                                           \
-    gemmlowp::ScopedProfilingLabel label("Broadcast" #name);                  \
-    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,   \
-                                     input2_dims, output_data, output_dims);  \
-  }                                                                           \
-  template <typename T>                                                       \
-  inline void Broadcast##name(                                                \
-      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
-      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
-      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
-      int32 input2_multiplier, int input2_shift, bool* output_data,           \
-      const Dims<4>& output_dims) {                                           \
-    gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit");          \
-    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,    \
-                                     input1_offset, input1_multiplier,        \
-                                     input1_shift, input2_data, input2_dims,  \
-                                     input2_offset, input2_multiplier,        \
-                                     input2_shift, output_data, output_dims); \
+#define TFLITE_COMPARISON_OP(name)                                             \
+  template <typename T>                                                        \
+  inline void name(const T* input1_data, const Dims<4>& input1_dims,           \
+                   const T* input2_data, const Dims<4>& input2_dims,           \
+                   bool* output_data, const Dims<4>& output_dims) {            \
+    gemmlowp::ScopedProfilingLabel label(#name);                               \
+    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,             \
+                            input2_dims, output_data, output_dims);            \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void name(                                                            \
+      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
+      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
+      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
+      int32 input2_multiplier, int input2_shift, bool* output_data,            \
+      const Dims<4>& output_dims) {                                            \
+    gemmlowp::ScopedProfilingLabel label(#name "/8bit");                       \
+    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,              \
+                            input1_offset, input1_multiplier, input1_shift,    \
+                            input2_data, input2_dims, input2_offset,           \
+                            input2_multiplier, input2_shift, output_data,      \
+                            output_dims);                                      \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void Broadcast##name(                                                 \
+      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,  \
+      const Dims<4>& input2_dims, bool* output_data,                           \
+      const Dims<4>& output_dims) {                                            \
+    gemmlowp::ScopedProfilingLabel label("Broadcast" #name);                   \
+    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,    \
+                                     input2_dims, output_data, output_dims);   \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void Broadcast##name(                                                 \
+      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
+      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
+      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
+      int32 input2_multiplier, int input2_shift, bool* output_data,            \
+      const Dims<4>& output_dims) {                                            \
+    gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit");           \
+    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,     \
+                                     input1_offset, input1_multiplier,         \
+                                     input1_shift, input2_data, input2_dims,   \
+                                     input2_offset, input2_multiplier,         \
+                                     input2_shift, output_data, output_dims);  \
+  }                                                                            \
+  inline void name(const ComparisonParams& op_params,                          \
+                   const RuntimeShape& input1_shape, const float* input1_data, \
+                   const RuntimeShape& input2_shape, const float* input2_data, \
+                   const RuntimeShape& output_shape, bool* output_data) {      \
+    gemmlowp::ScopedProfilingLabel label(#name);                               \
+    Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape,   \
+                         input2_data, output_shape, output_data);              \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void name##WithScaling(                                               \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label(#name "/8bit");                       \
+    ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data,   \
+                                       input2_shape, input2_data,              \
+                                       output_shape, output_data);             \
+  }                                                                            \
+  inline void Broadcast4DSlow##name(                                           \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const float* input1_data, const RuntimeShape& input2_shape,              \
+      const float* input2_data, const RuntimeShape& output_shape,              \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label("Broadcast" #name);                   \
+    BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data,  \
+                                        input2_shape, input2_data,             \
+                                        output_shape, output_data);            \
+  }                                                                            \
+  template <typename T>                                                        \
+  inline void Broadcast4DSlow##name##WithScaling(                              \
+      const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
+      const T* input1_data, const RuntimeShape& input2_shape,                  \
+      const T* input2_data, const RuntimeShape& output_shape,                  \
+      bool* output_data) {                                                     \
+    gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit");           \
+    BroadcastComparison4DSlowWithScaling<T, name##Fn>(                         \
+        op_params, input1_shape, input1_data, input2_shape, input2_data,       \
+        output_shape, output_data);                                            \
   }
 TFLITE_COMPARISON_OP(Equal);
 TFLITE_COMPARISON_OP(NotEqual);
@@ -4577,16 +4879,22 @@
 }
 
 template <typename T>
-inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
+inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
                                const T* input1_data,
-                               const RuntimeShape& input2_shape,
+                               const RuntimeShape& unextended_input2_shape,
                                const T* input2_data,
-                               const RuntimeShape& output_shape,
+                               const RuntimeShape& unextended_output_shape,
                                T* output_data) {
+  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+  const RuntimeShape output_shape =
+      RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
-  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
-                                      &desc2);
+  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+                                      unextended_input2_shape, &desc1, &desc2);
 
   for (int b = 0; b < output_shape.Dims(0); ++b) {
     for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4623,7 +4931,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
@@ -4662,7 +4970,7 @@
   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
-  RuntimeShape output_shape =
+  const RuntimeShape output_shape =
       RuntimeShape::ExtendedShape(4, unextended_output_shape);
 
   NdArrayDesc<4> desc1;
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index 5994fad..af5db10 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@
 #include <limits>
 #include <vector>
 #include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
-
 namespace strided_slice {
 
 // Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@
   return v;
 }
 
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+                                   int dim_count) {
+  // Add indices and mask bits to fully include extra dimensions
+  TFLITE_CHECK_LE(dim_count, 4);
+  TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+  TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+  TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+  const int pad_count = dim_count - p->start_indices_count;
+
+  // Pad indices at start, so move arrays by pad_count.
+  for (int i = p->start_indices_count - 1; i > 0; --i) {
+    p->strides[i + pad_count] = p->strides[i];
+    p->start_indices[i + pad_count] = p->start_indices[i];
+    p->stop_indices[i + pad_count] = p->stop_indices[i];
+  }
+  for (int i = 0; i < pad_count; ++i) {
+    p->start_indices[i] = 0;
+    p->stop_indices[i] = 0;
+    p->strides[i] = 1;
+  }
+
+  // Pad masks with 0s or 1s as required.
+  p->shrink_axis_mask <<= pad_count;
+  p->ellipsis_mask <<= pad_count;
+  p->new_axis_mask <<= pad_count;
+  p->begin_mask <<= pad_count;
+  p->end_mask <<= pad_count;
+  p->begin_mask |= (1 << pad_count) - 1;
+  p->end_mask |= (1 << pad_count) - 1;
+
+  p->start_indices_count = dim_count;
+  p->stop_indices_count = dim_count;
+  p->strides_count = dim_count;
+}
+
 // Return the index for the first element along that axis. This index will be a
 // positive integer between [0, axis_size - 1] that can be used to index
 // directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
-                        std::vector<IntType> const& start_indices,
-                        std::vector<IntType> const& strides,
-                        int const* input_shape, int axis) {
-  // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+                        const RuntimeShape& input_shape, int axis) {
+  const auto begin_mask = params.begin_mask;
+  const auto* start_indices = params.start_indices;
+  const auto* strides = params.strides;
+  // Begin with the specified index.
   int start = start_indices[axis];
 
   // begin_mask override
@@ -57,7 +93,7 @@
   }
 
   // Handle negative indices
-  int axis_size = input_shape[axis];
+  int axis_size = input_shape.Dims(axis);
   if (start < 0) {
     start += axis_size;
   }
@@ -73,11 +109,14 @@
 // element. ie. So if you were iterating through all elements of a 1D array of
 // size 4, this function would return 4 as the stop, because it is one past the
 // "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, int shrink_axis_mask,
-                       std::vector<IntType> const& stop_indices,
-                       std::vector<IntType> const& strides,
-                       int const* input_shape, int axis, int start_for_axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+                       const RuntimeShape& input_shape, int axis,
+                       int start_for_axis) {
+  const auto end_mask = params.end_mask;
+  const auto shrink_axis_mask = params.shrink_axis_mask;
+  const auto* stop_indices = params.stop_indices;
+  const auto* strides = params.strides;
+
   // Begin with the specified index
   const bool shrink_axis = shrink_axis_mask & (1 << axis);
   int stop = stop_indices[axis];
@@ -103,7 +142,7 @@
   }
 
   // Handle negative indices
-  const int axis_size = input_shape[axis];
+  const int axis_size = input_shape.Dims(axis);
   if (stop < 0) {
     stop += axis_size;
   }
@@ -127,6 +166,31 @@
   return stride > 0 ? index >= stop : index <= stop;
 }
 
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+    int begin_mask, int end_mask, int shrink_axis_mask,
+    const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+    const std::vector<int>& strides) {
+  tflite::StridedSliceParams op_params;
+  const int dims_count = start_indices.size();
+
+  op_params.start_indices_count = dims_count;
+  op_params.stop_indices_count = dims_count;
+  op_params.strides_count = dims_count;
+  for (int i = 0; i < dims_count; ++i) {
+    op_params.start_indices[i] = start_indices[i];
+    op_params.stop_indices[i] = stop_indices[i];
+    op_params.strides[i] = strides[i];
+  }
+
+  op_params.begin_mask = begin_mask;
+  op_params.ellipsis_mask = 0;
+  op_params.end_mask = end_mask;
+  op_params.new_axis_mask = 0;
+  op_params.shrink_axis_mask = shrink_axis_mask;
+
+  return op_params;
+}
+
 }  // namespace strided_slice
 
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ee2af5b..1310645 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -17,44 +17,12 @@
 
 #include <complex>
 #include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/contrib/lite/kernels/internal/types.h"
 
 namespace tflite {
 
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int16_t* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
 template <>
 inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
   return tensor != nullptr
@@ -62,39 +30,6 @@
              : nullptr;
 }
 
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
-template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
-  return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
 template <>
 inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
   return tensor != nullptr
@@ -102,56 +37,14 @@
              : nullptr;
 }
 
-inline int RemapDim(int max_dimensions, int d) {
-  return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
-  Dims<4> d;
-  for (int i = 0; i < 4; ++i) {
-    int src = size - i - 1;
-    if (src >= 0) {
-      d.sizes[i] = data[src];
-    } else {
-      d.sizes[i] = 1;
-    }
-  }
-  d.strides[0] = 1;
-  for (int i = 1; i < 4; i++) {
-    d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
-  }
-  return d;
-}
-
 inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
   return GetTensorDims(data.data(), data.size());
 }
 
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
-  if (tensor == nullptr) {
-    return Dims<4>();
-  }
-
-  auto* dims = tensor->dims;
-  return GetTensorDims(dims->data, dims->size);
-}
-
 inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
   return RuntimeShape(data.size(), data.data());
 }
 
-inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
-  if (tensor == nullptr) {
-    return RuntimeShape();
-  }
-
-  auto* dims = tensor->dims;
-  return RuntimeShape(dims->size, dims->data);
-}
-
 // A list of tensors in a format that can be used by kernels like split and
 // concatenation.
 template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000..77e22a0
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+  return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+  return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+  Dims<4> d;
+  for (int i = 0; i < 4; ++i) {
+    int src = size - i - 1;
+    if (src >= 0) {
+      d.sizes[i] = data[src];
+    } else {
+      d.sizes[i] = 1;
+    }
+  }
+  d.strides[0] = 1;
+  for (int i = 1; i < 4; i++) {
+    d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+  }
+  return d;
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+  if (tensor == nullptr) {
+    return Dims<4>();
+  }
+
+  auto* dims = tensor->dims;
+  return GetTensorDims(dims->data, dims->size);
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+  if (tensor == nullptr) {
+    return RuntimeShape();
+  }
+
+  TfLiteIntArray* dims = tensor->dims;
+  const int dims_size = dims->size;
+  const int32_t* dims_data = dims->data;
+  return RuntimeShape(dims_size, dims_data);
+}
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 748356d..b0fe5ad 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 #if defined(_MSC_VER)
 #define __restrict__ __restrict
@@ -113,6 +113,10 @@
                                              const float* batch_vector,
                                              int n_batch, float* result);
 
+// Adds the given vector to each batch of the batch vector, in place.
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+                          float* batch_vector);
+
 // Batch vector initialization with another vector.
 void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
                              float* batch_vector);
@@ -152,6 +156,12 @@
 // added to get one element of output.
 void ReductionSumVector(const float* input_vector, float* output_vector,
                         int output_size, int reduction_size);
+
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid divergence.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+                             int v_size, int n_batch,
+                             float normalization_epsilon);
 }  // namespace tensor_utils
 }  // namespace tflite
 
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 240fb64..6458af7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@
 ==============================================================================*/
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
 
 namespace tflite {
@@ -496,6 +496,16 @@
                   {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
 }
 
+TEST(uKernels, VectorBatchVectorAddTest) {
+  constexpr int kVectorSize = 3;
+  constexpr int kBatchSize = 2;
+  static float input[kVectorSize] = {0.0, -0.5, 1.0};
+  std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+  VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
+  EXPECT_THAT(output,
+              testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
+}
+
 TEST(uKernels, VectorBatchVectorAssignTest) {
   constexpr int kVectorSize = 5;
   constexpr int kBatchSize = 3;
@@ -712,5 +722,85 @@
   EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
 }
 
+TEST(uKernels, MeanStddevNormalizationNoneZeroInput) {
+  constexpr int kVectorSize = 4;
+  constexpr int kBatchSize = 2;
+  constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Non-zero input.
+  static float input[kVectorSize * kBatchSize] = {
+      0.1, 0.2, 0.3, 0.4,  // batch 0
+      0.9, 1.0, 1.1, 1.2,  // batch 1
+  };
+  std::vector<float> output(kVectorSize * kBatchSize);
+  MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+                          kNormalizationEpsilon);
+  const std::vector<float> expected_output = {
+      -1.34164071, -0.447213531, 0.44721365,  1.34164071,  // batch 0
+      -1.34163153, -0.447210163, 0.447211236, 1.3416326,   // batch 1
+  };
+  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationAllZeroInput) {
+  constexpr int kVectorSize = 4;
+  constexpr int kBatchSize = 2;
+  constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Zero input.
+  static float input[kVectorSize * kBatchSize] = {
+      0.0, 0.0, 0.0, 0.0,  // batch 0
+      0.0, 0.0, 0.0, 0.0,  // batch 1
+  };
+  std::vector<float> output(kVectorSize * kBatchSize);
+  MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+                          kNormalizationEpsilon);
+  const std::vector<float> expected_output = {
+      0.0, 0.0, 0.0, 0.0,  // batch 0
+      0.0, 0.0, 0.0, 0.0,  // batch 1
+  };
+  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationMixed) {
+  constexpr int kVectorSize = 4;
+  constexpr int kBatchSize = 2;
+  constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Mix of zero and non-zero input.
+  static float input[kVectorSize * kBatchSize] = {
+      0.0, 0.0, 0.0, 0.0,  // batch 0
+      0.1, 0.2, 0.3, 0.4,  // batch 1
+  };
+  std::vector<float> output(kVectorSize * kBatchSize);
+  MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+                          kNormalizationEpsilon);
+  const std::vector<float> expected_output = {
+      0.0,         0.0,          0.0,        0.0,         // batch 0
+      -1.34164071, -0.447213531, 0.44721365, 1.34164071,  // batch 1
+  };
+  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationSmallValue) {
+  constexpr int kVectorSize = 4;
+  constexpr int kBatchSize = 2;
+  constexpr float kNormalizationEpsilon = 1e-8;
+
+  // Small input values, some near the float precision limit.
+  static float input[kVectorSize * kBatchSize] = {
+      3e-5, -7e-6, -9e-5, 1e-6,  // batch 0
+      4e-5, 9e-6,  2e-4,  0.0,   // batch 1
+  };
+  std::vector<float> output(kVectorSize * kBatchSize);
+  MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+                          kNormalizationEpsilon);
+  const std::vector<float> expected_output = {
+      1.04231524,   0.212946132,  -1.64753067, 0.392269224,   // batch 0
+      -0.275023013, -0.658201098, 1.70267045,  -0.769446373,  // batch 1
+  };
+  EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 6ae4ebc..023707d 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -26,8 +26,8 @@
 enum class PaddingType : uint8 { kNone, kSame, kValid };
 
 struct PaddingValues {
-  int8 width;
-  int8 height;
+  int16 width;
+  int16 height;
 };
 
 // This enumeration allows for non-default formats for the weights array
@@ -720,12 +720,12 @@
 struct ComparisonParams {
   // uint8 inference params.
   int left_shift;
-  int32 input0_offset;
-  int32 input0_multiplier;
-  int input0_shift;
   int32 input1_offset;
   int32 input1_multiplier;
   int input1_shift;
+  int32 input2_offset;
+  int32 input2_multiplier;
+  int input2_shift;
   // Shape dependent / common to inference types.
   bool is_broadcast;
 };
@@ -734,10 +734,10 @@
   PaddingType padding_type;
   PaddingValues padding_values;
   // TODO(starka): This was just "stride", so check that width+height is OK.
-  int8 stride_width;
-  int8 stride_height;
-  int8 dilation_width_factor;
-  int8 dilation_height_factor;
+  int16 stride_width;
+  int16 stride_height;
+  int16 dilation_width_factor;
+  int16 dilation_height_factor;
   // uint8 inference params.
   // TODO(b/65838351): Use smaller types if appropriate.
   int32 input_offset;
@@ -745,8 +745,12 @@
   int32 output_offset;
   int32 output_multiplier;
   int output_shift;
-  int32 output_activation_min;
-  int32 output_activation_max;
+  // uint8, etc, activation params.
+  int32 quantized_activation_min;
+  int32 quantized_activation_max;
+  // float activation params.
+  float float_activation_min;
+  float float_activation_max;
 };
 
 struct DepthToSpaceParams {
@@ -756,8 +760,8 @@
 struct DepthwiseParams {
   PaddingType padding_type;
   PaddingValues padding_values;
-  int8 stride;
-  int8 depth_multiplier;
+  int16 stride;
+  int16 depth_multiplier;
   // uint8 inference params.
   // TODO(b/65838351): Use smaller types if appropriate.
   int32 input_offset;
@@ -765,8 +769,17 @@
   int32 output_offset;
   int32 output_multiplier;
   int output_shift;
-  int32 output_activation_min;
-  int32 output_activation_max;
+  // uint8, etc, activation params.
+  int32 quantized_activation_min;
+  int32 quantized_activation_max;
+  // float activation params.
+  float float_activation_min;
+  float float_activation_max;
+};
+
+struct DequantizationParams {
+  double scale;
+  int32 zero_point;
 };
 
 struct FakeQuantParams {
@@ -782,13 +795,17 @@
   int32 output_offset;
   int32 output_multiplier;
   int output_shift;
-  int32 output_activation_min;
-  int32 output_activation_max;
+  // uint8, etc, activation params.
+  int32 quantized_activation_min;
+  int32 quantized_activation_max;
+  // float activation params.
+  float float_activation_min;
+  float float_activation_max;
   FullyConnectedWeightsFormat weights_format;
 };
 
 struct GatherParams {
-  int8 input_rank;
+  int16 input_rank;
   int16 axis;
 };
 
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index ed46cd9..e9a5fd7 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -16,9 +16,10 @@
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
 
 #include <algorithm>
+#include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 5b3536d..e02d7df 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
new file mode 100644
index 0000000..1bbea67
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -0,0 +1,1316 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Layer Normalization LSTM op that applies normalization by mean and standard
+// deviation to the activation of the LSTM layers. Please see
+// https://arxiv.org/abs/1607.06450 for details.
+#include "flatbuffers/flexbuffers.h"  // flatbuffers
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace layer_norm_lstm {
+
+// Struct to hold Layer Norm LSTM option data.
+struct OpData {
+  TfLiteFusedActivation activation;
+  float cell_clip;
+  float proj_clip;
+  int scratch_tensor_index;
+};
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1;  // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9;    // Optional
+constexpr int kCellToForgetWeightsTensor = 10;  // Optional
+constexpr int kCellToOutputWeightsTensor = 11;  // Optional
+
+// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kInputLayerNormWeightsTensor = 12;
+constexpr int kForgetLayerNormWeightsTensor = 13;
+constexpr int kCellLayerNormWeightsTensor = 14;
+constexpr int kOutputLayerNormWeightsTensor = 15;
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 16;  // Optional
+constexpr int kForgetGateBiasTensor = 17;
+constexpr int kCellGateBiasTensor = 18;
+constexpr int kOutputGateBiasTensor = 19;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 20;  // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 21;  // Optional
+
+// State tensors.
+constexpr int kInputActivationStateTensor = 22;
+constexpr int kInputCellStateTensor = 23;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Total number of scratch tensors for hybrid Op.
+constexpr int kTensorsToAdd = 7;
+
+// Small float to avoid divergence during calculation of deviation.
+const float kLayerNormEpsilon = 1e-8;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* data = new OpData;
+
+  // Turn custom option data into flexbuffer map format.
+  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+  // Get activation function, cell_clip and proj_clip from the flexbuffer.
+  // TODO(b/113824099): make activation more generic.
+  assert(m["fused_activation_function"].ToString() == "TANH");
+  data->activation = kTfLiteActTanh;
+  data->cell_clip = m["cell_clip"].AsFloat();
+  data->proj_clip = m["proj_clip"].AsFloat();
+
+  // Populate scratch_tensor_index.
+  context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
+                      &data->scratch_tensor_index);
+  return data;
+}
+
+// Check that the input tensor dimensions match each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+                                        TfLiteNode* node, int n_input,
+                                        int n_output, int n_cell) {
+  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  // Making sure clipping parameters have valid values.
+  // == 0 means no clipping
+  //  > 0 means clipping
+  TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
+  TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
+
+  const TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  if (input_to_input_weights != nullptr) {
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+  }
+
+  const TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+  const TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+  const TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  if (recurrent_to_input_weights != nullptr) {
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+                      n_cell);
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+                      n_output);
+  }
+
+  const TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+                    n_cell);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+                    n_output);
+
+  const TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+                    n_output);
+
+  // We make sure the input-gate's parameters are either both present (regular
+  // LSTM) or not at all (CIFG-LSTM).
+  const bool cifg_weights_all_or_none =
+      ((input_to_input_weights != nullptr) &&
+       (recurrent_to_input_weights != nullptr)) ||
+      ((input_to_input_weights == nullptr) &&
+       (recurrent_to_input_weights == nullptr));
+  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+  const TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  if (cell_to_input_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+  }
+
+  const TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  if (cell_to_forget_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+  }
+
+  const TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+  if (cell_to_output_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+  }
+
+  // Making sure the peephole weights are there all or none.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool peephole_weights_all_or_none =
+      ((cell_to_input_weights != nullptr || use_cifg) &&
+       (cell_to_forget_weights != nullptr) &&
+       (cell_to_output_weights != nullptr)) ||
+      ((cell_to_input_weights == nullptr) &&
+       (cell_to_forget_weights == nullptr) &&
+       (cell_to_output_weights == nullptr));
+  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+  // Making sure layer norm weights are not null and have the right dimension.
+  const TfLiteTensor* input_layer_norm_weights =
+      GetInput(context, node, kInputLayerNormWeightsTensor);
+  TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
+  TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
+
+  const TfLiteTensor* forget_layer_norm_weights =
+      GetInput(context, node, kForgetLayerNormWeightsTensor);
+  TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
+  TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
+
+  const TfLiteTensor* cell_layer_norm_weights =
+      GetInput(context, node, kCellLayerNormWeightsTensor);
+  TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
+  TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
+
+  const TfLiteTensor* output_layer_norm_weights =
+      GetInput(context, node, kOutputLayerNormWeightsTensor);
+  TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
+  TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
+
+  // Make sure the input gate bias is present only when not a CIFG-LSTM.
+  const TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  if (use_cifg) {
+    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+  } else {
+    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+  }
+
+  const TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+  const TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+  const TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  if (projection_weights != nullptr) {
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+  }
+
+  const TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+  if (projection_bias != nullptr) {
+    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+  }
+
+  // Making sure the projection tensors are consistent:
+  // 1) If projection weight is not present, then projection bias should not be
+  // present.
+  // 2) If projection weight is present, then projection bias is optional.
+  const bool projection_tensors_consistent =
+      ((projection_weights != nullptr) || (projection_bias == nullptr));
+  TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+  return kTfLiteOk;
+}
+
+// Resize the output and state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Inferring batch size, number of outputs and number of cells from the
+  // input tensors.
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE(context, input->dims->size > 1);
+  const int n_batch = input->dims->data[0];
+  const int n_input = input->dims->data[1];
+
+  const TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+  const int n_cell = input_to_output_weights->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+  const TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+                    n_cell);
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that the input tensor dimensions match each other.
+  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+                                                        n_output, n_cell));
+
+  // Get the pointer to output, activation_state and cell_state tensors.
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  const TfLiteTensor* activation_state =
+      GetInput(context, node, kInputActivationStateTensor);
+  const TfLiteTensor* cell_state =
+      GetInput(context, node, kInputCellStateTensor);
+
+  // Check the shape of input state tensors.
+  // These tensors may be 1D or 2D. It's fine as long as the total size is
+  // correct.
+  TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+  // Resize the output tensors.
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+  output_size->data[0] = n_batch;
+  output_size->data[1] = n_output;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, output, output_size));
+
+  // The weights are of consistent type, so it suffices to check one.
+  const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+                             input->type == kTfLiteFloat32);
+
+  TfLiteIntArrayFree(node->temporaries);
+  if (is_hybrid_op) {
+    node->temporaries = TfLiteIntArrayCreate(7);
+  } else {
+    node->temporaries = TfLiteIntArrayCreate(1);
+  }
+  node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+  // Create a scratch buffer tensor.
+  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+  scratch_buffer->type = input->type;
+  scratch_buffer->allocation_type = kTfLiteArenaRw;
+
+  const TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+  scratch_buffer_size->data[0] = n_batch;
+  if (use_cifg) {
+    // Reserving space for Cell, Forget, Output gates
+    scratch_buffer_size->data[1] = n_cell * 3;
+  } else {
+    // Reserving space for Input, Cell, Forget, Output gates
+    scratch_buffer_size->data[1] = n_cell * 4;
+  }
+  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+                                                   scratch_buffer_size));
+
+  if (is_hybrid_op) {
+    // Allocate temporary tensors to store quantized values of input,
+    // activation_state and cell_state tensors.
+    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+    TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+    TfLiteTensor* activation_state_quantized =
+        GetTemporary(context, node, /*index=*/2);
+    activation_state_quantized->type = kTfLiteUInt8;
+    activation_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+                             activation_state->dims)) {
+      TfLiteIntArray* activation_state_quantized_size =
+          TfLiteIntArrayCopy(activation_state->dims);
+      TF_LITE_ENSURE_OK(
+          context, context->ResizeTensor(context, activation_state_quantized,
+                                         activation_state_quantized_size));
+    }
+    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+    TfLiteTensor* cell_state_quantized =
+        GetTemporary(context, node, /*index=*/3);
+    cell_state_quantized->type = kTfLiteUInt8;
+    cell_state_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+      TfLiteIntArray* cell_state_quantized_size =
+          TfLiteIntArrayCopy(cell_state->dims);
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, cell_state_quantized,
+                                              cell_state_quantized_size));
+    }
+
+    // Allocate temporary tensors to store scaling factors and product scaling
+    // factors. The latter is a convenience storage which allows to quantize
+    // a vector once (which produces the scaling factors) and multiply it with
+    // different matrices (which requires multiplying the scaling factors with
+    // the scaling factor of the matrix).
+    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+    TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
+    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+    TfLiteTensor* prod_scaling_factors =
+        GetTemporary(context, node, /*index=*/5);
+    prod_scaling_factors->type = kTfLiteFloat32;
+    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+    prod_scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+                             prod_scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, prod_scaling_factors,
+                                              prod_scaling_factors_size));
+    }
+
+    // Allocate a temporary tensor to store the recovered weights. Since
+    // this is used for diagonal matrices, only need to store n_cell values.
+    node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+    TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
+    recovered_weights->type = kTfLiteFloat32;
+    recovered_weights->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
+    recovered_weights_size->data[0] = n_cell;
+    if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) {
+      TF_LITE_ENSURE_OK(context,
+                        context->ResizeTensor(context, recovered_weights,
+                                              recovered_weights_size));
+    }
+  }
+  return kTfLiteOk;
+}
+
// Performs one time step of a float layer-norm LSTM cell over a whole batch.
//
// Computes the four gate pre-activations (input/forget/cell/output) from the
// current input and the previous output state, layer-normalizes each gate
// (MeanStddevNormalization followed by a per-cell layer-norm weight and the
// gate bias), updates cell_state_ptr and output_state_ptr in place, and
// writes the (optionally projected and clipped) result to output_ptr_batch.
//
// Optional features are detected from null pointers:
//  - CIFG (coupled input-forget gate): input_to_input_weights_ptr == nullptr;
//    the input gate is then derived as (1 - forget_gate).
//  - Peephole connections: cell_to_output_weights_ptr != nullptr.
//  - Projection layer: projection_weights_ptr != nullptr (bias optional).
//
// The four *_scratch buffers each hold n_batch * n_cell floats and are used
// as accumulators for the gate pre-activations; they are zeroed here.
void LayerNormLstmStep(
    const float* input_ptr_batch, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_weight_ptr,
    const float* forget_layer_norm_weight_ptr,
    const float* cell_layer_norm_weight_ptr,
    const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
    const float* output_gate_bias_ptr, const float* projection_weights_ptr,
    const float* projection_bias_ptr, float cell_clip, float proj_clip,
    const TfLiteFusedActivation& activation, int n_batch, int n_cell,
    int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
    float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
    float* output_gate_scratch, float* output_ptr_batch) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);

  // Initialize scratch buffers with 0.
  if (!use_cifg) {
    tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
  }
  tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
  tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
  tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);

  // For each batch and cell: compute input_weight * input.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
        input_gate_scratch, /*result_stride=*/1);
  }

  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
      forget_gate_scratch, /*result_stride=*/1);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
      cell_scratch, /*result_stride=*/1);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
      output_gate_scratch, /*result_stride=*/1);

  // For each batch and cell: compute recurrent_weight * output_state.
  if (!use_cifg) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
        n_batch, input_gate_scratch, /*result_stride=*/1);
  }
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, forget_gate_scratch,
      /*result_stride=*/1);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, cell_scratch, /*result_stride=*/1);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
      n_batch, output_gate_scratch,
      /*result_stride=*/1);

  // For each batch and cell: update input gate. Layer-norm is applied to the
  // accumulated pre-activation before the layer-norm weight, gate bias, and
  // sigmoid.
  if (!use_cifg) {
    if (use_peephole) {
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    tensor_utils::MeanStddevNormalization(input_gate_scratch,
                                          input_gate_scratch, n_cell, n_batch,
                                          kLayerNormEpsilon);
    tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
                                                n_cell, input_gate_scratch,
                                                n_batch, input_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                       input_gate_scratch);
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                        forget_gate_scratch, n_cell, n_batch,
                                        kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
                                              n_cell, forget_gate_scratch,
                                              n_batch, forget_gate_scratch);
  tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                     forget_gate_scratch);
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

  // For each batch and cell: update the cell.
  tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                        n_batch, kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(
      cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
  tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                     cell_scratch);
  // cell_state = forget_gate * cell_state (decay the previous cell state).
  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                         n_batch * n_cell, cell_state_ptr);
  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                        activation, cell_scratch);
  if (use_cifg) {
    // CIFG couples the gates: input gate = 1 - forget gate.
    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                             forget_gate_scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
  }
  if (cell_clip > 0.0) {
    tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
                             cell_state_ptr);
  }

  // For each batch and cell: update the output gate.
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                        output_gate_scratch, n_cell, n_batch,
                                        kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
                                              n_cell, output_gate_scratch,
                                              n_batch, output_gate_scratch);
  tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                     output_gate_scratch);
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);
  // hidden = output_gate * activation(cell_state); cell_scratch is reused to
  // hold the activated cell state.
  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
                                        activation, cell_scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                         n_batch * n_cell, output_gate_scratch);

  // For each batch: update the projection and output_state.
  const bool use_projection_weight = (projection_weights_ptr != nullptr);
  const bool use_projection_bias = (projection_bias_ptr != nullptr);
  if (use_projection_weight) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
                                            n_batch, output_ptr_batch);
    } else {
      tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
        output_ptr_batch, /*result_stride=*/1);
    if (proj_clip > 0.0) {
      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
                               output_ptr_batch);
    }
  } else {
    // No projection: n_output matches the gate width here, so the hidden
    // state is emitted directly.
    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                             output_ptr_batch);
  }
  // The output is also the next step's recurrent input (output_state).
  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
                           output_state_ptr);
}
+
// Performs one time step of a hybrid (quantized-weight) layer-norm LSTM cell.
// Mirrors the float LayerNormLstmStep above, but every weight matrix is
// symmetric-quantized int8 with a per-matrix scale; activations stay in float
// and are quantized on the fly, one batch row at a time.
//
// Quantization flow for each matmul:
//  1. SymmetricQuantizeFloats converts a float batch row to int8
//     (quantized_*_ptr) and yields a per-row scale in scaling_factors[b].
//  2. product_scaling_factors[b] = scaling_factors[b] * <weight scale> folds
//     both scales so the int8 matmul can accumulate directly into the float
//     gate scratch buffers.
// Peephole weights are diagonal (n_cell values), so they are dequantized into
// recovered_weights with VectorScalarMultiply instead of a quantized matmul.
//
// As an optimization, quantization + matmul are skipped whenever the source
// float vector is all zeros (common at the first time step).
void LayerNormLstmStep(
    const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
    float input_to_input_weights_scale,
    const int8_t* input_to_forget_weights_ptr,
    float input_to_forget_weights_scale,
    const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
    const int8_t* input_to_output_weights_ptr,
    float input_to_output_weights_scale,
    const int8_t* recurrent_to_input_weights_ptr,
    float recurrent_to_input_weights_scale,
    const int8_t* recurrent_to_forget_weights_ptr,
    float recurrent_to_forget_weights_scale,
    const int8_t* recurrent_to_cell_weights_ptr,
    float recurrent_to_cell_weights_scale,
    const int8_t* recurrent_to_output_weights_ptr,
    float recurrent_to_output_weights_scale,
    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
    const int8_t* cell_to_forget_weights_ptr,
    float cell_to_forget_weights_scale,
    const int8_t* cell_to_output_weights_ptr,
    float cell_to_output_weights_scale,
    const float* input_layer_norm_weight_ptr,
    const float* forget_layer_norm_weight_ptr,
    const float* cell_layer_norm_weight_ptr,
    const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
    const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
    const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
    float projection_weights_scale, const float* projection_bias_ptr,
    float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
    int n_batch, int n_cell, int n_input, int n_output,
    float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
    float* output_gate_scratch, float* scaling_factors,
    float* product_scaling_factors, float* recovered_weights,
    int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
    int8_t* quantized_cell_state_ptr, float* output_state_ptr,
    float* cell_state_ptr, float* output_ptr_batch) {
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  const bool use_peephole = (cell_to_output_weights_ptr != nullptr);

  // Initialize scratch buffers with 0.
  if (!use_cifg) {
    tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
  }
  tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
  tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
  tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);

  if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
    // Save quantization and matmul computation for all zero input.
    float unused_min, unused_max;
    // Quantize each batch row of the input once; the per-row scale is
    // combined with each weight matrix's scale below.
    for (int b = 0; b < n_batch; ++b) {
      const int offset = b * n_input;
      tensor_utils::SymmetricQuantizeFloats(
          input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
          &unused_min, &unused_max, &scaling_factors[b]);
    }
    // For each batch and cell: compute input_weight * input.
    if (!use_cifg) {
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * input_to_input_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_input_weights_ptr, n_cell, n_input,
          quantized_input_ptr_batch, product_scaling_factors, n_batch,
          input_gate_scratch, /*result_stride=*/1);
    }

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_forget_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
        product_scaling_factors, n_batch, forget_gate_scratch,
        /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_cell_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
        product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * input_to_output_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
        product_scaling_factors, n_batch, output_gate_scratch,
        /*result_stride=*/1);
  }

  if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
    // Save quantization and matmul computation for all zero input.
    float unused_min, unused_max;
    // Quantize the previous output state per batch row, reusing
    // scaling_factors (the input scales above are no longer needed).
    for (int b = 0; b < n_batch; ++b) {
      const int offset = b * n_output;
      tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
                                            quantized_output_state_ptr + offset,
                                            &unused_min, &unused_max,
                                            &scaling_factors[b]);
    }
    // For each batch and cell: compute recurrent_weight * output_state.
    if (!use_cifg) {
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * recurrent_to_input_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_input_weights_ptr, n_cell, n_output,
          quantized_output_state_ptr, product_scaling_factors, n_batch,
          input_gate_scratch, /*result_stride=*/1);
    }

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_forget_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_forget_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        forget_gate_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_cell_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_cell_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        cell_scratch, /*result_stride=*/1);

    for (int b = 0; b < n_batch; ++b) {
      product_scaling_factors[b] =
          scaling_factors[b] * recurrent_to_output_weights_scale;
    }
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        recurrent_to_output_weights_ptr, n_cell, n_output,
        quantized_output_state_ptr, product_scaling_factors, n_batch,
        output_gate_scratch, /*result_stride=*/1);
  }

  // Save quantization and matmul computation for all zero input.
  bool is_cell_state_all_zeros =
      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);

  // For each batch and cell: update input gate.
  if (!use_cifg) {
    if (use_peephole && !is_cell_state_all_zeros) {
      // Dequantize the diagonal peephole weights into recovered_weights,
      // then apply them to the cell state.
      tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
                                         cell_to_input_weights_scale,
                                         recovered_weights);
      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
          recovered_weights, n_cell, cell_state_ptr, n_batch,
          input_gate_scratch);
    }
    tensor_utils::MeanStddevNormalization(input_gate_scratch,
                                          input_gate_scratch, n_cell, n_batch,
                                          kLayerNormEpsilon);
    tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
                                                n_cell, input_gate_scratch,
                                                n_batch, input_gate_scratch);
    tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
                                       input_gate_scratch);
    tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                       input_gate_scratch);
  }

  // For each batch and cell: update forget gate.
  if (use_peephole && !is_cell_state_all_zeros) {
    tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
                                       cell_to_forget_weights_scale,
                                       recovered_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_weights, n_cell, cell_state_ptr, n_batch,
        forget_gate_scratch);
  }
  tensor_utils::MeanStddevNormalization(forget_gate_scratch,
                                        forget_gate_scratch, n_cell, n_batch,
                                        kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
                                              n_cell, forget_gate_scratch,
                                              n_batch, forget_gate_scratch);
  tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
                                     forget_gate_scratch);
  tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                     forget_gate_scratch);

  // For each batch and cell: update the cell.
  tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
                                        n_batch, kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(
      cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
  tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
                                     cell_scratch);
  // cell_state = forget_gate * cell_state (decay the previous cell state).
  tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
                                         n_batch * n_cell, cell_state_ptr);
  tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
                                        activation, cell_scratch);
  if (use_cifg) {
    // CIFG couples the gates: input gate = 1 - forget gate.
    tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                             forget_gate_scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
  }
  if (cell_clip > 0.0) {
    tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
                             cell_state_ptr);
  }

  // Re-check after the cell update so the output-gate peephole can also be
  // skipped for an all-zero cell state.
  is_cell_state_all_zeros =
      tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
  // For each batch and cell: update the output gate.
  if (use_peephole && !is_cell_state_all_zeros) {
    tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
                                       cell_to_output_weights_scale,
                                       recovered_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_weights, n_cell, cell_state_ptr, n_batch,
        output_gate_scratch);
  }
  tensor_utils::MeanStddevNormalization(output_gate_scratch,
                                        output_gate_scratch, n_cell, n_batch,
                                        kLayerNormEpsilon);
  tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
                                              n_cell, output_gate_scratch,
                                              n_batch, output_gate_scratch);
  tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
                                     output_gate_scratch);
  tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                     output_gate_scratch);
  // hidden = output_gate * activation(cell_state); cell_scratch is reused to
  // hold the activated cell state.
  tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
                                        activation, cell_scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                         n_batch * n_cell, output_gate_scratch);

  // For each batch: update the projection and output_state.
  const bool use_projection_weight = (projection_weights_ptr != nullptr);
  const bool use_projection_bias = (projection_bias_ptr != nullptr);
  if (use_projection_weight) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
                                            n_batch, output_ptr_batch);
    } else {
      tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
    }
    if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero input.
      float unused_min, unused_max;
      // quantized_cell_state_ptr is reused here as scratch for the quantized
      // hidden state feeding the projection matmul.
      for (int b = 0; b < n_batch; ++b) {
        const int offset = b * n_cell;
        tensor_utils::SymmetricQuantizeFloats(
            output_gate_scratch + offset, n_cell,
            quantized_cell_state_ptr + offset, &unused_min, &unused_max,
            &scaling_factors[b]);
      }
      for (int b = 0; b < n_batch; ++b) {
        product_scaling_factors[b] =
            scaling_factors[b] * projection_weights_scale;
      }
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
          product_scaling_factors, n_batch, output_ptr_batch,
          /*result_stride=*/1);
    }
    if (proj_clip > 0.0) {
      tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
                               output_ptr_batch);
    }
  } else {
    // No projection: n_output matches the gate width here, so the hidden
    // state is emitted directly.
    tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
                             output_ptr_batch);
  }
  // The output is also the next step's recurrent input (output_state).
  tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
                           output_state_ptr);
}
+
+// The LayerNormLSTM Op engine.
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_layer_norm_weights,
+    const TfLiteTensor* forget_layer_norm_weights,
+    const TfLiteTensor* cell_layer_norm_weights,
+    const TfLiteTensor* output_layer_norm_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int n_batch = input->dims->data[0];
+  const int n_input = input->dims->data[1];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const float* input_ptr_batch = input->data.f;
+  const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+  const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+  const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+  const float* recurrent_to_forget_weights_ptr =
+      recurrent_to_forget_weights->data.f;
+  const float* recurrent_to_cell_weights_ptr =
+      recurrent_to_cell_weights->data.f;
+  const float* recurrent_to_output_weights_ptr =
+      recurrent_to_output_weights->data.f;
+  const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+  const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+  const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+  const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* activation_state_ptr = activation_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+  float* output_ptr_batch = output->data.f;
+
+  LayerNormLstmStep(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+      input_to_cell_weights_ptr, input_to_output_weights_ptr,
+      recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+      recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+      cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+      cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+      forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+      output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+      cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+      projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+      n_input, n_output, activation_state_ptr, cell_state_ptr,
+      input_gate_scratch, forget_gate_scratch, cell_scratch,
+      output_gate_scratch, output_ptr_batch);
+
+  return kTfLiteOk;
+}
+
+// Runs one invocation of the hybrid layer-norm LSTM kernel. In the hybrid
+// path the weight tensors are 8-bit quantized (stored as uint8 in the tensor
+// union and reinterpreted as int8) with per-tensor scales, while the input,
+// states, biases and layer-norm weights remain float. The quantized
+// temporaries (input_quantized, activation_state_quantized,
+// cell_state_quantized) and the scaling-factor scratch tensors are passed
+// through to LayerNormLstmStep, which presumably quantizes activations on the
+// fly before the matrix products. Always returns kTfLiteOk.
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights,
+    const TfLiteTensor* input_layer_norm_weights,
+    const TfLiteTensor* forget_layer_norm_weights,
+    const TfLiteTensor* cell_layer_norm_weights,
+    const TfLiteTensor* output_layer_norm_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+    TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  const int n_batch = input->dims->data[0];
+  const int n_input = input->dims->data[1];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Partition scratch_buffer into per-gate areas of n_cell * n_batch floats:
+  // three areas under CIFG (no input gate), four otherwise.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  // Absent quantized weights keep a scale of 1.0f so downstream math is a
+  // harmless identity if the pointer is (correctly) never dereferenced.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  // Peephole weights; the input-gate peephole additionally requires non-CIFG.
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const float* input_ptr_batch = input->data.f;
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  // Layer-norm weights and biases stay float even in the hybrid path.
+  const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+  const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+  const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+  const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* activation_state_ptr = activation_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+  float* output_ptr_batch = output->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_activation_state_ptr =
+      reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_weights_ptr = recovered_weights->data.f;
+
+  LayerNormLstmStep(
+      input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+      input_to_forget_weights_ptr, input_to_forget_weights_scale,
+      input_to_cell_weights_ptr, input_to_cell_weights_scale,
+      input_to_output_weights_ptr, input_to_output_weights_scale,
+      recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+      recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+      recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+      recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+      cell_to_input_weights_ptr, cell_to_input_weights_scale,
+      cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+      cell_to_output_weights_ptr, cell_to_output_weights_scale,
+      input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+      cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+      input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+      output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+      projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+      n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+      output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+      recovered_weights_ptr, quantized_input_ptr,
+      quantized_activation_state_ptr, quantized_cell_state_ptr,
+      activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+  return kTfLiteOk;
+}
+
+// Kernel entry point: gathers all input/temporary/output tensors for the
+// layer-norm LSTM node and dispatches to the float or hybrid implementation
+// based on the weight tensor type (input_to_output_weights is used as the
+// representative, since weights are all float or all quantized).
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  // Optional tensors (CIFG / peephole / projection variants) may be null.
+  const TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  const TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  const TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  const TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+
+  const TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  const TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  const TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  const TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+  const TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  const TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  const TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+  const TfLiteTensor* input_layer_norm_weights =
+      GetInput(context, node, kInputLayerNormWeightsTensor);
+  const TfLiteTensor* forget_layer_norm_weights =
+      GetInput(context, node, kForgetLayerNormWeightsTensor);
+  const TfLiteTensor* cell_layer_norm_weights =
+      GetInput(context, node, kCellLayerNormWeightsTensor);
+  const TfLiteTensor* output_layer_norm_weights =
+      GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+  const TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  const TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  const TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+
+  const TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  const TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+  // The state tensors are inputs that are updated in place across
+  // invocations, hence the non-const access through context->tensors.
+  TfLiteTensor* activation_state =
+      &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+  TfLiteTensor* cell_state =
+      &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (input_to_output_weights->type) {
+    case kTfLiteFloat32: {
+      return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+                       input_to_cell_weights, input_to_output_weights,
+                       recurrent_to_input_weights, recurrent_to_forget_weights,
+                       recurrent_to_cell_weights, recurrent_to_output_weights,
+                       cell_to_input_weights, cell_to_forget_weights,
+                       cell_to_output_weights, input_layer_norm_weights,
+                       forget_layer_norm_weights, cell_layer_norm_weights,
+                       output_layer_norm_weights, input_gate_bias,
+                       forget_gate_bias, cell_bias, output_gate_bias,
+                       projection_weights, projection_bias, op_data->cell_clip,
+                       op_data->proj_clip, op_data->activation, scratch_buffer,
+                       activation_state, cell_state, output);
+    }
+    case kTfLiteUInt8: {
+      // Hybrid path needs extra temporaries for on-the-fly quantization.
+      TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+      TfLiteTensor* activation_state_quantized =
+          GetTemporary(context, node, /*index=*/2);
+      TfLiteTensor* cell_state_quantized =
+          GetTemporary(context, node, /*index=*/3);
+      TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+      TfLiteTensor* prod_scaling_factors =
+          GetTemporary(context, node, /*index=*/5);
+      TfLiteTensor* recovered_weights =
+          GetTemporary(context, node, /*index=*/6);
+      return EvalHybrid(
+          input, input_to_input_weights, input_to_forget_weights,
+          input_to_cell_weights, input_to_output_weights,
+          recurrent_to_input_weights, recurrent_to_forget_weights,
+          recurrent_to_cell_weights, recurrent_to_output_weights,
+          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+          input_layer_norm_weights, forget_layer_norm_weights,
+          cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+          forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+          projection_bias, op_data->cell_clip, op_data->proj_clip,
+          op_data->activation, scratch_buffer, scaling_factors,
+          prod_scaling_factors, recovered_weights, input_quantized,
+          activation_state_quantized, cell_state_quantized, activation_state,
+          cell_state, output);
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           input_to_output_weights->type);
+      return kTfLiteError;
+  }
+  // Unreachable: every switch case above returns.
+  return kTfLiteOk;
+}
+
+// Releases the per-node user data. `buffer` is the OpData pointer previously
+// stored in node->user_data.
+void Free(TfLiteContext* context, void* buffer) {
+  OpData* op_data = reinterpret_cast<OpData*>(buffer);
+  delete op_data;
+}
+
+}  // namespace layer_norm_lstm
+
+// Returns the (lazily constructed, process-lifetime) registration record for
+// the LAYER_NORM_LSTM custom op.
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+  static TfLiteRegistration registration = {
+      /*init=*/layer_norm_lstm::Init,
+      /*free=*/layer_norm_lstm::Free,
+      /*prepare=*/layer_norm_lstm::Prepare,
+      /*invoke=*/layer_norm_lstm::Eval};
+  return &registration;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000..abc229f
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Layer Norm LSTM op.
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"  // flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Test harness around the LAYER_NORM_LSTM custom op. The constructor wires up
+// every input tensor of the op (disabled optional tensors are added as null
+// inputs so tensor indices stay stable), serializes the clip values and the
+// fused activation into the op's custom-options flexbuffer, and exposes typed
+// setters for each tensor. `weight_type` selects float (default) or, in the
+// hybrid subclass, uint8 quantized weights.
+class LayerNormLSTMOpModel : public SingleOpModel {
+ public:
+  LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+                       bool use_cifg, bool use_peephole,
+                       bool use_projection_weights, bool use_projection_bias,
+                       float cell_clip, float proj_clip,
+                       const std::vector<std::vector<int>>& input_shapes,
+                       const TensorType& weight_type = TensorType_FLOAT32)
+      : n_batch_(n_batch),
+        n_input_(n_input),
+        n_cell_(n_cell),
+        n_output_(n_output) {
+    input_ = AddInput(TensorType_FLOAT32);
+
+    // Input-to-gate weights; the input-gate weights are absent under CIFG.
+    if (use_cifg) {
+      input_to_input_weights_ = AddNullInput();
+    } else {
+      input_to_input_weights_ = AddInput(weight_type);
+    }
+
+    input_to_forget_weights_ = AddInput(weight_type);
+    input_to_cell_weights_ = AddInput(weight_type);
+    input_to_output_weights_ = AddInput(weight_type);
+
+    if (use_cifg) {
+      recurrent_to_input_weights_ = AddNullInput();
+    } else {
+      recurrent_to_input_weights_ = AddInput(weight_type);
+    }
+
+    recurrent_to_forget_weights_ = AddInput(weight_type);
+    recurrent_to_cell_weights_ = AddInput(weight_type);
+    recurrent_to_output_weights_ = AddInput(weight_type);
+
+    // Peephole (cell-to-gate) weights are optional as a group; the input-gate
+    // peephole additionally requires non-CIFG.
+    if (use_peephole) {
+      if (use_cifg) {
+        cell_to_input_weights_ = AddNullInput();
+      } else {
+        cell_to_input_weights_ = AddInput(weight_type);
+      }
+      cell_to_forget_weights_ = AddInput(weight_type);
+      cell_to_output_weights_ = AddInput(weight_type);
+    } else {
+      cell_to_input_weights_ = AddNullInput();
+      cell_to_forget_weights_ = AddNullInput();
+      cell_to_output_weights_ = AddNullInput();
+    }
+
+    // Layer-norm weights are always float, even for quantized weight types.
+    input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+    forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+    cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+    output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+
+    if (use_cifg) {
+      input_gate_bias_ = AddNullInput();
+    } else {
+      input_gate_bias_ = AddInput(TensorType_FLOAT32);
+    }
+    forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+    cell_bias_ = AddInput(TensorType_FLOAT32);
+    output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+    if (use_projection_weights) {
+      projection_weights_ = AddInput(weight_type);
+      if (use_projection_bias) {
+        projection_bias_ = AddInput(TensorType_FLOAT32);
+      } else {
+        projection_bias_ = AddNullInput();
+      }
+    } else {
+      projection_weights_ = AddNullInput();
+      projection_bias_ = AddNullInput();
+    }
+
+    // Adding the 2 state tensors.
+    output_state_ =
+        AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+    cell_state_ =
+        AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
+    output_ = AddOutput(TensorType_FLOAT32);
+
+    // Set up and pass in custom options using flexbuffer.
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {
+      // cell_clip and proj_clip are floats; writing them with Int() would
+      // truncate any fractional clip value to zero before the kernel ever
+      // sees it, so they must be stored with Float().
+      fbb.Float("cell_clip", cell_clip);
+      fbb.Float("proj_clip", proj_clip);
+      fbb.String("fused_activation_function", "TANH");
+    });
+    fbb.Finish();
+    SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
+    BuildInterpreter(input_shapes);
+  }
+
+  // Per-tensor setters; each populates the tensor added in the constructor.
+  void SetInputToInputWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_to_input_weights_, f);
+  }
+
+  void SetInputToForgetWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_to_forget_weights_, f);
+  }
+
+  void SetInputToCellWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_to_cell_weights_, f);
+  }
+
+  void SetInputToOutputWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_to_output_weights_, f);
+  }
+
+  void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+    PopulateTensor(recurrent_to_input_weights_, f);
+  }
+
+  void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+    PopulateTensor(recurrent_to_forget_weights_, f);
+  }
+
+  void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+    PopulateTensor(recurrent_to_cell_weights_, f);
+  }
+
+  void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+    PopulateTensor(recurrent_to_output_weights_, f);
+  }
+
+  void SetCellToInputWeights(std::initializer_list<float> f) {
+    PopulateTensor(cell_to_input_weights_, f);
+  }
+
+  void SetCellToForgetWeights(std::initializer_list<float> f) {
+    PopulateTensor(cell_to_forget_weights_, f);
+  }
+
+  void SetCellToOutputWeights(std::initializer_list<float> f) {
+    PopulateTensor(cell_to_output_weights_, f);
+  }
+
+  void SetInputLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_layer_norm_weights_, f);
+  }
+
+  void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(forget_layer_norm_weights_, f);
+  }
+
+  void SetCellLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(cell_layer_norm_weights_, f);
+  }
+
+  void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(output_layer_norm_weights_, f);
+  }
+
+  void SetInputGateBias(std::initializer_list<float> f) {
+    PopulateTensor(input_gate_bias_, f);
+  }
+
+  void SetForgetGateBias(std::initializer_list<float> f) {
+    PopulateTensor(forget_gate_bias_, f);
+  }
+
+  void SetCellBias(std::initializer_list<float> f) {
+    PopulateTensor(cell_bias_, f);
+  }
+
+  void SetOutputGateBias(std::initializer_list<float> f) {
+    PopulateTensor(output_gate_bias_, f);
+  }
+
+  void SetProjectionWeights(std::initializer_list<float> f) {
+    PopulateTensor(projection_weights_, f);
+  }
+
+  void SetProjectionBias(std::initializer_list<float> f) {
+    PopulateTensor(projection_bias_, f);
+  }
+
+  // Writes [begin, end) into the input tensor starting at `offset` floats.
+  void SetInput(int offset, const float* begin, const float* end) {
+    PopulateTensor(input_, offset, const_cast<float*>(begin),
+                   const_cast<float*>(end));
+  }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+  int num_inputs() { return n_input_; }
+  int num_outputs() { return n_output_; }
+  int num_cells() { return n_cell_; }
+  int num_batches() { return n_batch_; }
+
+ protected:
+  // Tensor indices within the interpreter, as returned by AddInput/AddOutput.
+  int input_;
+  int input_to_input_weights_;
+  int input_to_forget_weights_;
+  int input_to_cell_weights_;
+  int input_to_output_weights_;
+
+  int recurrent_to_input_weights_;
+  int recurrent_to_forget_weights_;
+  int recurrent_to_cell_weights_;
+  int recurrent_to_output_weights_;
+
+  int cell_to_input_weights_;
+  int cell_to_forget_weights_;
+  int cell_to_output_weights_;
+
+  int input_layer_norm_weights_;
+  int forget_layer_norm_weights_;
+  int cell_layer_norm_weights_;
+  int output_layer_norm_weights_;
+
+  int input_gate_bias_;
+  int forget_gate_bias_;
+  int cell_bias_;
+  int output_gate_bias_;
+
+  int projection_weights_;
+  int projection_bias_;
+
+  int output_state_;
+  int cell_state_;
+
+  int output_;
+
+  // Model dimensions.
+  int n_batch_;
+  int n_input_;
+  int n_cell_;
+  int n_output_;
+};
+
+// Hybrid variant of the model harness: weight tensors are built as
+// TensorType_UINT8 and populated via symmetric quantization, while the
+// layer-norm weights (and, inherited from the base class, the biases and
+// input) remain float. Shadows the base-class setters for every quantized
+// tensor.
+class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
+ public:
+  HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+                             bool use_cifg, bool use_peephole,
+                             bool use_projection_weights,
+                             bool use_projection_bias, float cell_clip,
+                             float proj_clip,
+                             const std::vector<std::vector<int>>& input_shapes)
+      : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
+                             use_peephole, use_projection_weights,
+                             use_projection_bias, cell_clip, proj_clip,
+                             input_shapes, TensorType_UINT8) {}
+
+  void SetInputToInputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+  }
+
+  void SetInputToForgetWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+  }
+
+  void SetInputToCellWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+  }
+
+  void SetInputToOutputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+  }
+
+  void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+  }
+
+  void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+  }
+
+  void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+  }
+
+  void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+  }
+
+  void SetCellToInputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+  }
+
+  void SetCellToForgetWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+  }
+
+  void SetCellToOutputWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+  }
+
+  // Layer-norm weights stay float in the hybrid path.
+  void SetInputLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(input_layer_norm_weights_, f);
+  }
+
+  void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(forget_layer_norm_weights_, f);
+  }
+
+  void SetCellLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(cell_layer_norm_weights_, f);
+  }
+
+  void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+    PopulateTensor(output_layer_norm_weights_, f);
+  }
+
+  void SetProjectionWeights(std::initializer_list<float> f) {
+    SymmetricQuantizeAndPopulate(projection_weights_, f);
+  }
+};
+
+// Base fixture holding the model parameters shared by the layer-norm LSTM
+// tests, plus the golden-output verification loop.
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+  // Weights of the Layer Norm LSTM model. Some are optional.
+  std::initializer_list<float> input_to_input_weights_;
+  std::initializer_list<float> input_to_cell_weights_;
+  std::initializer_list<float> input_to_forget_weights_;
+  std::initializer_list<float> input_to_output_weights_;
+  std::initializer_list<float> input_gate_bias_;
+  std::initializer_list<float> cell_gate_bias_;
+  std::initializer_list<float> forget_gate_bias_;
+  std::initializer_list<float> output_gate_bias_;
+  std::initializer_list<float> recurrent_to_input_weights_;
+  std::initializer_list<float> recurrent_to_cell_weights_;
+  std::initializer_list<float> recurrent_to_forget_weights_;
+  std::initializer_list<float> recurrent_to_output_weights_;
+  std::initializer_list<float> cell_to_input_weights_;
+  std::initializer_list<float> cell_to_forget_weights_;
+  std::initializer_list<float> cell_to_output_weights_;
+  std::initializer_list<float> input_layer_norm_weights_;
+  std::initializer_list<float> forget_layer_norm_weights_;
+  std::initializer_list<float> cell_layer_norm_weights_;
+  std::initializer_list<float> output_layer_norm_weights_;
+  std::initializer_list<float> projection_weights_;
+
+  // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
+  std::vector<std::vector<float>> layer_norm_lstm_input_;
+
+  // Compares output up to tolerance to the result of the layer_norm_lstm given
+  // the input. `input` and `output` each hold one flattened sequence per
+  // batch; the model is invoked once per timestep and its output compared
+  // against the matching slice of the goldens.
+  void VerifyGoldens(const std::vector<std::vector<float>>& input,
+                     const std::vector<std::vector<float>>& output,
+                     LayerNormLSTMOpModel* layer_norm_lstm,
+                     float tolerance = 1e-5) {
+    const int num_batches = input.size();
+    EXPECT_GT(num_batches, 0);
+    const int num_inputs = layer_norm_lstm->num_inputs();
+    EXPECT_GT(num_inputs, 0);
+    // Sequence length is implied by the flattened per-batch input size.
+    const int input_sequence_size = input[0].size() / num_inputs;
+    EXPECT_GT(input_sequence_size, 0);
+    for (int i = 0; i < input_sequence_size; ++i) {
+      // Feed timestep i of every batch into the input tensor.
+      for (int b = 0; b < num_batches; ++b) {
+        const float* batch_start = input[b].data() + i * num_inputs;
+        const float* batch_end = batch_start + num_inputs;
+
+        layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
+                                  batch_start, batch_end);
+      }
+
+      layer_norm_lstm->Invoke();
+
+      // Collect the expected outputs for this timestep across batches and
+      // compare element-wise within tolerance.
+      const int num_outputs = layer_norm_lstm->num_outputs();
+      std::vector<float> expected;
+      for (int b = 0; b < num_batches; ++b) {
+        const float* golden_start_batch = output[b].data() + i * num_outputs;
+        const float* golden_end_batch = golden_start_batch + num_outputs;
+        expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+      }
+      EXPECT_THAT(layer_norm_lstm->GetOutput(),
+                  ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+    }
+  }
+};
+
+// Fixture data for a non-CIFG layer-norm LSTM with peephole and projection.
+// The weight shapes imply n_input=5, n_cell=4, n_output=3 (e.g. 20-element
+// input-to-gate weights = 4x5, 12-element recurrent weights = 4x3,
+// 4-element gate biases) and a 2-batch, 3-step input sequence.
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
+    : public BaseLayerNormLstmTest {
+  void SetUp() override {
+    input_to_input_weights_ = {0.5,  0.6,  0.7,  -0.8, -0.9, 0.1,  0.2,
+                               0.3,  -0.4, 0.5,  -0.8, 0.7,  -0.6, 0.5,
+                               -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+    input_to_forget_weights_ = {-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2,
+                                -0.4, 0.3,  -0.8, -0.4, 0.3,  -0.5, -0.4,
+                                -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5};
+
+    input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5,  -0.2,
+                              -0.3, -0.2, -0.6, 0.6,  -0.1, -0.4, -0.3,
+                              -0.7, 0.7,  -0.9, -0.5, 0.8,  0.6};
+
+    input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
+                                -0.3, -0.8, -0.2, 0.6,  -0.2, 0.4,  -0.7,
+                                -0.3, -0.5, 0.1,  0.5,  -0.6, -0.4};
+
+    input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
+
+    forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
+
+    cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
+
+    output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
+
+    recurrent_to_input_weights_ = {-0.2, -0.3, 0.4,  0.1,  -0.5, 0.9,
+                                   -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+    recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8,  -0.08,
+                                  -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+    recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
+                                    0.9,  0.3,  -0.1, 0.2,  0.5, 0.2};
+
+    recurrent_to_output_weights_ = {0.3,  -0.1, 0.1,  -0.2, -0.5, -0.7,
+                                    -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+    // Peephole weights, one per cell.
+    cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
+
+    cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
+
+    cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
+
+    input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
+    forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
+    cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
+    output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
+
+    // Projection from n_cell=4 down to n_output=3.
+    projection_weights_ = {-0.1, 0.2,  0.01, -0.2, 0.1,  0.5,
+                           0.3,  0.08, 0.07, 0.2,  -0.4, 0.2};
+
+    layer_norm_lstm_input_ = {
+        {// Batch0: 3 (input_sequence_size) * 5 (n_input)
+         0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
+         0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
+         0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2
+
+        {// Batch1: 3 (input_sequence_size) * 5 (n_input)
+         0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
+         0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
+         0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
+    };
+  }
+};
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+       LayerNormLstmBlackBoxTest) {
+  const int n_batch = 2;
+  const int n_input = 5;
+  const int n_cell = 4;
+  const int n_output = 3;
+  const float cell_clip = 0.0;
+  const float proj_clip = 0.0;
+
+  LayerNormLSTMOpModel layer_norm_lstm(
+      n_batch, n_input, n_cell, n_output,
+      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*use_projection_weights=*/true,
+      /*use_projection_bias=*/false, cell_clip, proj_clip,
+      {
+          {n_batch, n_input},  // input tensor
+
+          {n_cell, n_input},  // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {n_cell, n_output},  // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {n_cell},  // cell_to_input_weight tensor
+          {n_cell},  // cell_to_forget_weight tensor
+          {n_cell},  // cell_to_output_weight tensor
+
+          {n_cell},  // input_layer_norm_weight tensor
+          {n_cell},  // forget_layer_norm_weight tensor
+          {n_cell},  // cell_layer_norm_weight tensor
+          {n_cell},  // output_layer_norm_weight tensor
+
+          {n_cell},  // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {n_output, n_cell},  // projection_weight tensor
+          {0},                 // projection_bias tensor
+      });
+
+  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+  layer_norm_lstm.SetCellBias(cell_gate_bias_);
+  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+  layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+  layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+  layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+  layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+  // Verify the final output.
+  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+      {
+          // Batch0: 3 (input_sequence_size) * 3 (n_output)
+          0.0244077, 0.128027, -0.00170918,  // seq 0
+          0.0137642, 0.140751, 0.0395835,    // seq 1
+          -0.00459231, 0.155278, 0.0837377,  // seq 2
+      },
+      {
+          // Batch1: 3 (input_sequence_size) * 3 (n_output)
+          -0.00692428, 0.0848741, 0.063445,  // seq 0
+          -0.00403912, 0.139963, 0.072681,   // seq 1
+          0.00752706, 0.161903, 0.0561371,   // seq 2
+      }};
+
+  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+                &layer_norm_lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+       HybridLayerNormLstmBlackBoxTest) {
+  const int n_batch = 2;
+  const int n_input = 5;
+  const int n_cell = 4;
+  const int n_output = 3;
+  const float cell_clip = 0.0;
+  const float proj_clip = 0.0;
+
+  HybridLayerNormLSTMOpModel layer_norm_lstm(
+      n_batch, n_input, n_cell, n_output,
+      /*use_cifg=*/false, /*use_peephole=*/true,
+      /*use_projection_weights=*/true,
+      /*use_projection_bias=*/false, cell_clip, proj_clip,
+      {
+          {n_batch, n_input},  // input tensor
+
+          {n_cell, n_input},  // input_to_input_weight tensor
+          {n_cell, n_input},  // input_to_forget_weight tensor
+          {n_cell, n_input},  // input_to_cell_weight tensor
+          {n_cell, n_input},  // input_to_output_weight tensor
+
+          {n_cell, n_output},  // recurrent_to_input_weight tensor
+          {n_cell, n_output},  // recurrent_to_forget_weight tensor
+          {n_cell, n_output},  // recurrent_to_cell_weight tensor
+          {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+          {n_cell},  // cell_to_input_weight tensor
+          {n_cell},  // cell_to_forget_weight tensor
+          {n_cell},  // cell_to_output_weight tensor
+
+          {n_cell},  // input_layer_norm_weight tensor
+          {n_cell},  // forget_layer_norm_weight tensor
+          {n_cell},  // cell_layer_norm_weight tensor
+          {n_cell},  // output_layer_norm_weight tensor
+
+          {n_cell},  // input_gate_bias tensor
+          {n_cell},  // forget_gate_bias tensor
+          {n_cell},  // cell_bias tensor
+          {n_cell},  // output_gate_bias tensor
+
+          {n_output, n_cell},  // projection_weight tensor
+          {0},                 // projection_bias tensor
+      });
+
+  layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+  layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+  layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+  layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+  layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+  layer_norm_lstm.SetCellBias(cell_gate_bias_);
+  layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+  layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+  layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+  layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+  layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+  layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+  layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+  layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+  layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+  layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+  layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+  layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+  layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+  layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+  const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+      {
+          // Batch0: 3 (input_sequence_size) * 3 (n_output)
+          0.0244576, 0.127847, -0.00181765,  // seq 0
+          0.0137518, 0.140892, 0.0402234,    // seq 1
+          -0.0048839, 0.155096, 0.0840309,   // seq 2
+      },
+      {
+          // Batch1: 3 (input_sequence_size) * 3 (n_output)
+          -0.00728636, 0.0843957, 0.0634786,  // seq 0
+          -0.00448382, 0.139278, 0.0737372,   // seq 1
+          0.00734616, 0.161793, 0.0560238,    // seq 2
+      }};
+
+  VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+                &layer_norm_lstm);
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 799c152..334d2a2 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index c71f3b4..f770cb3 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 6ed9385..6edb646 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -59,8 +59,8 @@
 #include <limits>
 #include <memory>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 #include "util/hash/farmhash.h"
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 74dc3f2..aaa3ce9 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -20,8 +20,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/gemm_support.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 0308a39..7cb0146 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 306f676..66cf147 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
 #include "flatbuffers/flexbuffers.h"  // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
 #include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 92d8bc8..e0aac8a 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 4124c05..0ddd064 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
index 9ff3dca..910aed6 100644
--- a/tensorflow/contrib/lite/kernels/one_hot.cc
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index cc326a7..4cb98fd 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 55bcf3b..0d93940 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -92,8 +92,8 @@
                       op_context.constant_values->type);
   }
 
-  // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
-  TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+  // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+  TF_LITE_ENSURE(context, op_context.dims <= 4);
 
   // Exit early if paddings is a non-const tensor. Set output tensor to
   // dynamic so output size can be determined in Eval.
@@ -134,21 +134,21 @@
     after_padding.push_back(paddings_data[idx * 2 + 1]);
   }
 
-#define TF_LITE_PAD(type, scalar, pad_value)                          \
-  TF_LITE_ENSURE_EQ(context, before_padding.size(), 4);               \
-  TF_LITE_ENSURE_EQ(context, after_padding.size(), 4);                \
-  tflite::PadParams op_params;                                        \
-  op_params.left_padding_count = 4;                                   \
-  op_params.right_padding_count = 4;                                  \
-  for (int i = 0; i < 4; ++i) {                                       \
-    op_params.left_padding[i] = before_padding[3 - i];                \
-    op_params.right_padding[i] = after_padding[3 - i];                \
-  }                                                                   \
-  const scalar pad_value_copy = pad_value;                            \
-                                                                      \
-  type::Pad(op_params, GetTensorShape(op_context.input),              \
-            GetTensorData<scalar>(op_context.input), &pad_value_copy, \
-            GetTensorShape(op_context.output),                        \
+#define TF_LITE_PAD(type, scalar, pad_value)                             \
+  TF_LITE_ENSURE(context, before_padding.size() <= 4);                   \
+  TF_LITE_ENSURE(context, after_padding.size() <= 4);                    \
+  tflite::PadParams op_params;                                           \
+  op_params.left_padding_count = before_padding.size();                  \
+  op_params.right_padding_count = after_padding.size();                  \
+  for (int i = 0; i < op_context.dims; ++i) {                            \
+    op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
+    op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
+  }                                                                      \
+  const scalar pad_value_copy = pad_value;                               \
+                                                                         \
+  type::Pad(op_params, GetTensorShape(op_context.input),                 \
+            GetTensorData<scalar>(op_context.input), &pad_value_copy,    \
+            GetTensorShape(op_context.output),                           \
             GetTensorData<scalar>(op_context.output))
   switch (op_context.input->type) {
     case kTfLiteFloat32: {
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f8b9064..f663899 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -193,7 +193,7 @@
       PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
                       {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
                       {TensorType_FLOAT32}),
-      "dims != 4");
+      "dims <= 4");
 }
 
 TEST(PadOpTest, UnequalDimensions) {
@@ -221,6 +221,15 @@
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
 }
 
+TEST(PadOpTest, SimpleConst1DTest) {
+  PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2},
+                    {TensorType_FLOAT32});
+  m.SetInput({2, 3});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0}));
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
 TEST(PadOpTest, SimpleDynamicTest) {
   PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
                       {TensorType_FLOAT32});
@@ -334,7 +343,7 @@
                    {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
                    {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
                    {TensorType_FLOAT32}),
-               "dims != 4");
+               "dims <= 4");
 }
 
 TEST(PadV2OpTest, UnequalDimensions) {
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f1..42b6b45 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 29a5be0..6451142 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index d676de5..1e96cc8 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index ca83797..d94d821 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -15,8 +15,8 @@
 #include <string.h>
 #include <limits>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 30ca752..1debf10 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -22,8 +22,10 @@
 namespace custom {
 
 TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
 TfLiteRegistration* Register_MFCC();
 TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+TfLiteRegistration* Register_RELU_1();
 
 }  // namespace custom
 
@@ -116,6 +118,7 @@
 TfLiteRegistration* Register_LOGICAL_NOT();
 TfLiteRegistration* Register_UNPACK();
 TfLiteRegistration* Register_FLOOR_DIV();
+TfLiteRegistration* Register_SQUARE();
 
 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
   context->ReportError(
@@ -241,6 +244,7 @@
   AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
   AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
+  AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
 
 #if 0
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
@@ -248,6 +252,8 @@
   AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
   AddCustom("AudioSpectrogram",
             tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+  AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
+  AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
   AddCustom("TFLite_Detection_PostProcess",
             tflite::ops::custom::Register_DETECTION_POSTPROCESS());
 #endif
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 0296152..61856ab 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@
 #define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
 
 #include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
 
 namespace tflite {
 namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc
new file mode 100644
index 0000000..abafee2
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace relu1 {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  output->type = input->type;
+  return context->ResizeTensor(context, output,
+                               TfLiteIntArrayCopy(input->dims));
+}
+
+// This is derived from lite/kernels/activations.cc.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, 0);
+  TfLiteTensor* output = GetOutput(context, node, 0);
+  const int elements = NumElements(input);
+  const float* in = input->data.f;
+  const float* in_end = in + elements;
+  float* out = output->data.f;
+  for (; in < in_end; ++in, ++out) {
+    *out = std::min(std::max(0.f, *in), 1.f);
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace relu1
+
+TfLiteRegistration* Register_RELU_1() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 relu1::Prepare, relu1::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
new file mode 100644
index 0000000..c1e0149
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"  // flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_RELU_1();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+  explicit BaseActivationsOpModel(const TensorData& input) {
+    input_ = AddInput(input);
+    output_ = AddOutput({input.type, {}});
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {});
+    fbb.Finish();
+    SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
+    BuildInterpreter({GetShape(input_)});
+  }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+  using BaseActivationsOpModel::BaseActivationsOpModel;
+
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor(input_, data);
+  }
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(FloatActivationsOpTest, Relu1) {
+  FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+  m.SetInput({
+      0.0, -0.6, 0.2, -0.4,  //
+      0.3, -2.0, 1.1, -0.1,  //
+  });
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0.0, 0.0, 0.2, 0.0,  //
+                                 0.3, 0.0, 1.0, 0.0,  //
+                             }));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 49ba057..f41147b 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 #include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index dafa3ae..fb045d1 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3cdb5db..3959502 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
index dbcd2ef..66d4c9e 100644
--- a/tensorflow/contrib/lite/kernels/shape.cc
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b..de80a40 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@
 #include <string>
 #include <vector>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
 #include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 55e1650..ccfee41 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@
 #include <string.h>
 #include <cmath>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 8332ae3..3a10d2e 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9238e87..64c56c0 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index fec2a6f..178568e 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b144486..719e2dc 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662..080c51c 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index bed2117..87ffcc4 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@
 #include <string.h>
 #include <cmath>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 77a1f59..1be0c83 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 6ba7959..9903fd5 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -23,8 +23,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 9156917..0fdb0a3 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -74,8 +74,8 @@
       CustomOptionsFormat_FLEXBUFFERS));
 }
 
-void SingleOpModel::BuildInterpreter(
-    std::vector<std::vector<int>> input_shapes) {
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                                     bool allow_fp32_relax_to_fp16) {
   auto opcodes = builder_.CreateVector(opcodes_);
   auto operators = builder_.CreateVector(operators_);
   auto tensors = builder_.CreateVector(tensors_);
@@ -113,6 +113,8 @@
     CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
   }
 
+  interpreter_->SetAllowFp16PrecisionForFp32(allow_fp32_relax_to_fp16);
+
   // Modify delegate with function.
   if (apply_delegate_fn_) {
     apply_delegate_fn_(interpreter_.get());
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index bedbe93..84deb0e 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -182,7 +182,8 @@
 
   // Build the interpreter for this model. Also, resize and allocate all
   // tensors given the shapes of the inputs.
-  void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+  void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+                        bool allow_fp32_relax_to_fp16 = false);
 
   void Invoke();
 
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index 5181a8f..49421eb 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
index 4f78c22..e73ca7b 100644
--- a/tensorflow/contrib/lite/kernels/tile_test.cc
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 2dd760b..6c38b6739 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <algorithm>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 2abb89b..16106fd 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@
 limitations under the License.
 ==============================================================================*/
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 800b056..9535996 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@
 ==============================================================================*/
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index a9baa5c..6f2d98e 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index c678f14..63817bd 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0180c2c..744ee7c 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@
 #include <iostream>
 #include <limits>
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/activation_functor.h"
 #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
index 4998f88..9ff06f8 100644
--- a/tensorflow/contrib/lite/kernels/unpack.cc
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -13,8 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index 0294ec8..2d4707f 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -15,7 +15,7 @@
 #ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
 #define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
index fa9a3cd..92934d1 100644
--- a/tensorflow/contrib/lite/mmap_allocation.cc
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -20,7 +20,7 @@
 #include <unistd.h>
 
 #include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index aa410ab..6311d60 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -20,8 +20,9 @@
 #include <sys/types.h>
 
 #include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
 #include "tensorflow/contrib/lite/model.h"
 #ifndef TFLITE_MCU
 #include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -42,41 +43,6 @@
 
 const char* kEmptyTensorName = "";
 
-TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
-                               ErrorReporter* error_reporter) {
-  switch (tensor_type) {
-    case TensorType_FLOAT32:
-      *type = kTfLiteFloat32;
-      break;
-    case TensorType_INT16:
-      *type = kTfLiteInt16;
-      break;
-    case TensorType_INT32:
-      *type = kTfLiteInt32;
-      break;
-    case TensorType_UINT8:
-      *type = kTfLiteUInt8;
-      break;
-    case TensorType_INT64:
-      *type = kTfLiteInt64;
-      break;
-    case TensorType_STRING:
-      *type = kTfLiteString;
-      break;
-    case TensorType_BOOL:
-      *type = kTfLiteBool;
-      break;
-    case TensorType_COMPLEX64:
-      *type = kTfLiteComplex64;
-      break;
-    default:
-      error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
-                             EnumNameTensorType(tensor_type), tensor_type);
-      return kTfLiteError;
-  }
-  return kTfLiteOk;
-}
-
 #ifndef TFLITE_MCU
 // Loads a model from `filename`. If `mmap_file` is true then use mmap,
 // otherwise make a copy of the model in a buffer.
@@ -198,39 +164,10 @@
   auto opcodes = model_->operator_codes();
   for (const OperatorCode* opcode : *opcodes) {
     const TfLiteRegistration* registration = nullptr;
-    auto builtin_code = opcode->builtin_code();
-    int version = opcode->version();
-
-    if (builtin_code > BuiltinOperator_MAX ||
-        builtin_code < BuiltinOperator_MIN) {
-      error_reporter_->Report(
-          "Op builtin_code out or range: %d. Are you using old TFLite binary "
-          "with newer model?",
-          builtin_code);
-      status = kTfLiteError;
-    } else if (builtin_code != BuiltinOperator_CUSTOM) {
-      registration = op_resolver_.FindOp(builtin_code, version);
-      if (registration == nullptr) {
-        error_reporter_->Report(
-            "Didn't find op for builtin opcode '%s' version '%d'\n",
-            EnumNameBuiltinOperator(builtin_code), version);
-        status = kTfLiteError;
-      }
-    } else if (!opcode->custom_code()) {
-      error_reporter_->Report(
-          "Operator with CUSTOM builtin_code has no custom_code.\n");
-      status = kTfLiteError;
-    } else {
-      const char* name = opcode->custom_code()->c_str();
-      registration = op_resolver_.FindOp(name, version);
-      flatbuffer_op_index_to_registration_types_.push_back(
-          BuiltinOperator_CUSTOM);
-      if (registration == nullptr) {
-        error_reporter_->Report(
-            "Didn't find custom op for name '%s' with version %d\n", name,
-            version);
-        status = kTfLiteError;
-      }
+    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+                                       &registration);
+    if (status != kTfLiteOk) {
+      return status;
     }
     flatbuffer_op_index_to_registration_.push_back(registration);
   }
@@ -240,6 +177,11 @@
 namespace {
 template <class T>
 std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
+  // Return an empty shape for tensors with a null shape vector; empty vectors
+  // are serialized as nullptr by models constructed via flatbuffers::Pack.
+  if (flat_array == nullptr) {
+    return {};
+  }
   std::vector<int> ret(flat_array->Length());
   for (int i = 0; i < flat_array->Length(); i++) {
     ret[i] = flat_array->Get(i);
@@ -247,565 +189,6 @@
   return ret;
 }
 
-// Copies the contents from the flatbuffer int vector `flatbuffer` into the
-// int array `buffer`. `flat_vector` and `buffer` represent the same
-// configuration operation for a given operation.
-void FlatBufferIntVectorToArray(int max_size_of_buffer,
-                                const flatbuffers::Vector<int32_t>* flat_vector,
-                                int* buffer, ErrorReporter* error_reporter) {
-  if (!flat_vector) {
-    error_reporter->Report("Input array not provided for operation.\n");
-  } else {
-    int num_dimensions = flat_vector->Length();
-    if (num_dimensions > max_size_of_buffer / sizeof(int)) {
-      error_reporter->Report(
-          "Found too many dimensions in the operation's input array.\n");
-    } else {
-      for (int i = 0; i < num_dimensions; ++i) {
-        buffer[i] = flat_vector->Get(i);
-      }
-    }
-  }
-}
-
-// Allocate a structure using C malloc, but make sure the structure is a
-// POD structure that doesn't require constructors to run. The reason we do
-// this, is that Interpreter's C extension part will take ownership and wants
-// to use malloc() and free().
-template <class T>
-T* MallocPOD() {
-  static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
-  return static_cast<T*>(malloc(sizeof(T)));
-}
-
-// Parse the appropriate data out of the op.
-//
-// This handles builtin data explicitly as there are flatbuffer schemas.
-// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
-// need to be released by calling `free`.`
-// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
-TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
-                         ErrorReporter* error_reporter, void** builtin_data) {
-  auto parse_padding = [](Padding padding) {
-    switch (padding) {
-      case Padding_SAME:
-        return kTfLitePaddingSame;
-      case Padding_VALID:
-        return kTfLitePaddingValid;
-    }
-    return kTfLitePaddingUnknown;
-  };
-  auto parse_activation = [](ActivationFunctionType activation) {
-    switch (activation) {
-      case ActivationFunctionType_NONE:
-        return kTfLiteActNone;
-      case ActivationFunctionType_RELU:
-        return kTfLiteActRelu;
-      case ActivationFunctionType_RELU_N1_TO_1:
-        return kTfLiteActRelu1;
-      case ActivationFunctionType_RELU6:
-        return kTfLiteActRelu6;
-      case ActivationFunctionType_TANH:
-        return kTfLiteActTanh;
-      case ActivationFunctionType_SIGN_BIT:
-        return kTfLiteActSignBit;
-    }
-    return kTfLiteActNone;
-  };
-  auto parseLSHProjectionType = [](LSHProjectionType type) {
-    switch (type) {
-      case LSHProjectionType_SPARSE:
-        return kTfLiteLshProjectionSparse;
-      case LSHProjectionType_DENSE:
-        return kTfLiteLshProjectionDense;
-      default:
-        return kTfLiteLshProjectionUnknown;
-    }
-  };
-  auto parseCombinerType = [](CombinerType type) {
-    switch (type) {
-      case CombinerType_MEAN:
-        return kTfLiteCombinerTypeMean;
-      case CombinerType_SQRTN:
-        return kTfLiteCombinerTypeSqrtn;
-      case CombinerType_SUM:
-      default:
-        return kTfLiteCombinerTypeSum;
-    }
-  };
-
-  *builtin_data = nullptr;
-  switch (op_type) {
-    case BuiltinOperator_CONV_2D: {
-      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
-      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
-        params->padding = parse_padding(conv_params->padding());
-        params->stride_width = conv_params->stride_w();
-        params->stride_height = conv_params->stride_h();
-        params->activation =
-            parse_activation(conv_params->fused_activation_function());
-
-        params->dilation_width_factor = conv_params->dilation_w_factor();
-        params->dilation_height_factor = conv_params->dilation_h_factor();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_CAST: {
-      TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
-      if (auto* schema_params = op->builtin_options_as_CastOptions()) {
-        auto in_status =
-            ConvertTensorType(schema_params->in_data_type(),
-                              &params->in_data_type, error_reporter);
-        auto out_status =
-            ConvertTensorType(schema_params->out_data_type(),
-                              &params->out_data_type, error_reporter);
-        if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
-          free(params);
-          return kTfLiteError;
-        }
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_LSH_PROJECTION: {
-      TfLiteLSHProjectionParams* params =
-          MallocPOD<TfLiteLSHProjectionParams>();
-      if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
-        params->type = parseLSHProjectionType(lshParams->type());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_AVERAGE_POOL_2D:
-    case BuiltinOperator_MAX_POOL_2D:
-    case BuiltinOperator_L2_POOL_2D: {
-      TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
-      if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
-        params->padding = parse_padding(pool_params->padding());
-        params->stride_width = pool_params->stride_w();
-        params->stride_height = pool_params->stride_h();
-        params->filter_width = pool_params->filter_width();
-        params->filter_height = pool_params->filter_height();
-        params->activation =
-            parse_activation(pool_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_DEPTHWISE_CONV_2D: {
-      TfLiteDepthwiseConvParams* params =
-          MallocPOD<TfLiteDepthwiseConvParams>();
-      if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
-        params->padding = parse_padding(conv_params->padding());
-        params->stride_width = conv_params->stride_w();
-        params->stride_height = conv_params->stride_h();
-        params->depth_multiplier = conv_params->depth_multiplier();
-        params->activation =
-            parse_activation(conv_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SVDF: {
-      TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
-      if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
-        params->rank = svdf_params->rank();
-        params->activation =
-            parse_activation(svdf_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
-    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
-      TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
-      if (auto* sequence_rnn_params =
-              op->builtin_options_as_SequenceRNNOptions()) {
-        params->activation =
-            parse_activation(sequence_rnn_params->fused_activation_function());
-        params->time_major = sequence_rnn_params->time_major();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_RNN: {
-      TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
-      if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
-        params->activation =
-            parse_activation(rnn_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
-      TfLiteEmbeddingLookupSparseParams* params =
-          MallocPOD<TfLiteEmbeddingLookupSparseParams>();
-      if (auto* embedding_params =
-              op->builtin_options_as_EmbeddingLookupSparseOptions()) {
-        params->combiner = parseCombinerType(embedding_params->combiner());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_FULLY_CONNECTED: {
-      TfLiteFullyConnectedParams* params =
-          MallocPOD<TfLiteFullyConnectedParams>();
-      if (auto* fully_connected_params =
-              op->builtin_options_as_FullyConnectedOptions()) {
-        params->activation = parse_activation(
-            fully_connected_params->fused_activation_function());
-        switch (fully_connected_params->weights_format()) {
-          case FullyConnectedOptionsWeightsFormat_DEFAULT:
-            params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
-            break;
-          case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
-            params->weights_format =
-                kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
-            break;
-          default:
-            error_reporter->Report("Unhandled fully-connected weights format.");
-            return kTfLiteError;
-        }
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_HASHTABLE_LOOKUP:
-      // no-op.
-      break;
-    case BuiltinOperator_SOFTMAX: {
-      TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
-      if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
-        params->beta = softmax_params->beta();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_CONCATENATION: {
-      TfLiteConcatenationParams* params =
-          MallocPOD<TfLiteConcatenationParams>();
-      if (auto* concatenation_params =
-              op->builtin_options_as_ConcatenationOptions()) {
-        params->activation =
-            parse_activation(concatenation_params->fused_activation_function());
-        params->axis = concatenation_params->axis();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_MUL: {
-      auto* params = MallocPOD<TfLiteMulParams>();
-      if (auto* schema_params = op->builtin_options_as_MulOptions()) {
-        params->activation =
-            parse_activation(schema_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_ADD: {
-      auto* params = MallocPOD<TfLiteAddParams>();
-      if (auto* schema_params = op->builtin_options_as_AddOptions()) {
-        params->activation =
-            parse_activation(schema_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_DIV: {
-      auto* params = MallocPOD<TfLiteDivParams>();
-      if (auto* schema_params = op->builtin_options_as_DivOptions()) {
-        params->activation =
-            parse_activation(schema_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SUB: {
-      auto* params = MallocPOD<TfLiteSubParams>();
-      if (auto* schema_params = op->builtin_options_as_SubOptions()) {
-        params->activation =
-            parse_activation(schema_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_L2_NORMALIZATION: {
-      auto* params = MallocPOD<TfLiteL2NormParams>();
-      if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
-        params->activation =
-            parse_activation(schema_params->fused_activation_function());
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
-      auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
-      if (auto* schema_params =
-              op->builtin_options_as_LocalResponseNormalizationOptions()) {
-        params->radius = schema_params->radius();
-        params->bias = schema_params->bias();
-        params->alpha = schema_params->alpha();
-        params->beta = schema_params->beta();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
-    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
-    case BuiltinOperator_LSTM: {
-      TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
-      if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
-        params->activation =
-            parse_activation(lstm_params->fused_activation_function());
-        params->cell_clip = lstm_params->cell_clip();
-        params->proj_clip = lstm_params->proj_clip();
-        switch (lstm_params->kernel_type()) {
-          case LSTMKernelType_FULL:
-            params->kernel_type = kTfLiteLSTMFullKernel;
-            break;
-          case LSTMKernelType_BASIC:
-            params->kernel_type = kTfLiteLSTMBasicKernel;
-            break;
-        }
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_RESIZE_BILINEAR: {
-      auto* params = MallocPOD<TfLiteResizeBilinearParams>();
-      if (auto* schema_params =
-              op->builtin_options_as_ResizeBilinearOptions()) {
-        params->align_corners = schema_params->align_corners();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_RESHAPE: {
-      auto* params = MallocPOD<TfLiteReshapeParams>();
-      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
-        auto* new_shape = schema_params->new_shape();
-        FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
-                                   params->shape, error_reporter);
-        params->num_dimensions = new_shape->Length();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SKIP_GRAM: {
-      TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
-      if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
-        params->ngram_size = skip_gram_params->ngram_size();
-        params->max_skip_size = skip_gram_params->max_skip_size();
-        params->include_all_ngrams = skip_gram_params->include_all_ngrams();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SPACE_TO_DEPTH: {
-      auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
-      if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
-        params->block_size = schema_params->block_size();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_GATHER: {
-      TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
-      params->axis = 0;
-      if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
-        params->axis = gather_params->axis();
-      }
-
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_MEAN:
-    case BuiltinOperator_REDUCE_MAX:
-    case BuiltinOperator_REDUCE_MIN:
-    case BuiltinOperator_REDUCE_PROD:
-    case BuiltinOperator_SUM:
-    case BuiltinOperator_REDUCE_ANY: {
-      auto* params = MallocPOD<TfLiteReducerParams>();
-      if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
-        params->keep_dims = schema_params->keep_dims();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SPLIT: {
-      auto* params = MallocPOD<TfLiteSplitParams>();
-      if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
-        params->num_splits = schema_params->num_splits();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SQUEEZE: {
-      auto* params = MallocPOD<TfLiteSqueezeParams>();
-      if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
-        const auto& squeeze_dims = schema_params->squeeze_dims();
-        FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
-                                   params->squeeze_dims, error_reporter);
-        params->num_squeeze_dims = squeeze_dims->Length();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_STRIDED_SLICE: {
-      auto* params = MallocPOD<TfLiteStridedSliceParams>();
-      if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
-        params->begin_mask = schema_params->begin_mask();
-        params->end_mask = schema_params->end_mask();
-        params->ellipsis_mask = schema_params->ellipsis_mask();
-        params->new_axis_mask = schema_params->new_axis_mask();
-        params->shrink_axis_mask = schema_params->shrink_axis_mask();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_ARG_MAX: {
-      auto* params = MallocPOD<TfLiteArgMaxParams>();
-      if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
-        ConvertTensorType(schema_params->output_type(), &params->output_type,
-                          error_reporter);
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_ARG_MIN: {
-      auto* params = MallocPOD<TfLiteArgMinParams>();
-      if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
-        ConvertTensorType(schema_params->output_type(), &params->output_type,
-                          error_reporter);
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_TRANSPOSE_CONV: {
-      TfLiteTransposeConvParams* params =
-          MallocPOD<TfLiteTransposeConvParams>();
-      if (auto* transpose_conv_params =
-              op->builtin_options_as_TransposeConvOptions()) {
-        params->padding = parse_padding(transpose_conv_params->padding());
-        params->stride_width = transpose_conv_params->stride_w();
-        params->stride_height = transpose_conv_params->stride_h();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SPARSE_TO_DENSE: {
-      TfLiteSparseToDenseParams* params =
-          MallocPOD<TfLiteSparseToDenseParams>();
-      if (auto* sparse_to_dense_params =
-              op->builtin_options_as_SparseToDenseOptions()) {
-        params->validate_indices = sparse_to_dense_params->validate_indices();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_SHAPE: {
-      auto* params = MallocPOD<TfLiteShapeParams>();
-      if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
-        ConvertTensorType(schema_params->out_type(), &params->out_type,
-                          error_reporter);
-      }
-      *builtin_data = static_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_PACK: {
-      TfLitePackParams* params = MallocPOD<TfLitePackParams>();
-      if (auto* pack_params = op->builtin_options_as_PackOptions()) {
-        params->values_count = pack_params->values_count();
-        params->axis = pack_params->axis();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_DELEGATE: {
-      // TODO(ycling): Revisit when supporting saving delegated models.
-      error_reporter->Report("DELEGATE op shouldn't exist in model.");
-      return kTfLiteError;
-    }
-    case BuiltinOperator_FAKE_QUANT: {
-      auto* params = MallocPOD<TfLiteFakeQuantParams>();
-      if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
-        params->min = schema_params->min();
-        params->max = schema_params->max();
-        params->num_bits = schema_params->num_bits();
-        params->narrow_range = schema_params->narrow_range();
-      }
-      *builtin_data = static_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_ONE_HOT: {
-      auto* params = MallocPOD<TfLiteOneHotParams>();
-      if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
-        params->axis = schema_params->axis();
-      }
-      *builtin_data = static_cast<void*>(params);
-      break;
-    }
-    case BuiltinOperator_UNPACK: {
-      TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
-      if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
-        params->num = unpack_params->num();
-        params->axis = unpack_params->axis();
-      }
-      *builtin_data = reinterpret_cast<void*>(params);
-      break;
-    }
-
-    // Below are the ops with no builtin_data strcture.
-    case BuiltinOperator_BATCH_TO_SPACE_ND:
-    // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
-    // ok for now, since there is no call implementation either.
-    case BuiltinOperator_CALL:
-    case BuiltinOperator_CONCAT_EMBEDDINGS:
-    case BuiltinOperator_CUSTOM:
-    case BuiltinOperator_DEQUANTIZE:
-    case BuiltinOperator_EMBEDDING_LOOKUP:
-    case BuiltinOperator_EQUAL:
-    case BuiltinOperator_EXP:
-    case BuiltinOperator_EXPAND_DIMS:
-    case BuiltinOperator_FLOOR:
-    case BuiltinOperator_GREATER:
-    case BuiltinOperator_GREATER_EQUAL:
-    case BuiltinOperator_LESS:
-    case BuiltinOperator_LESS_EQUAL:
-    case BuiltinOperator_LOG:
-    case BuiltinOperator_LOGISTIC:
-    case BuiltinOperator_LOG_SOFTMAX:
-    case BuiltinOperator_MAXIMUM:
-    case BuiltinOperator_MINIMUM:
-    case BuiltinOperator_NEG:
-    case BuiltinOperator_NOT_EQUAL:
-    case BuiltinOperator_PAD:
-    case BuiltinOperator_PADV2:
-    case BuiltinOperator_PRELU:
-    case BuiltinOperator_RELU:
-    case BuiltinOperator_RELU6:
-    case BuiltinOperator_RELU_N1_TO_1:
-    case BuiltinOperator_RSQRT:
-    case BuiltinOperator_SELECT:
-    case BuiltinOperator_SIN:
-    case BuiltinOperator_SLICE:
-    case BuiltinOperator_SPACE_TO_BATCH_ND:
-    case BuiltinOperator_SQRT:
-    case BuiltinOperator_TANH:
-    case BuiltinOperator_TILE:
-    case BuiltinOperator_TOPK_V2:
-    case BuiltinOperator_TRANSPOSE:
-    case BuiltinOperator_POW:
-    case BuiltinOperator_LOGICAL_OR:
-    case BuiltinOperator_LOGICAL_AND:
-    case BuiltinOperator_LOGICAL_NOT:
-    case BuiltinOperator_FLOOR_DIV:
-      break;
-  }
-  return kTfLiteOk;
-}
-
 }  // namespace
 
 TfLiteStatus InterpreterBuilder::ParseNodes(
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8bc9ecd..6abdfcd 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -35,9 +35,10 @@
 #define TENSORFLOW_CONTRIB_LITE_MODEL_H_
 
 #include <memory>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
 #include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
 #include "tensorflow/contrib/lite/schema/schema_generated.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index df4f60d..ec7d46a 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -23,7 +23,7 @@
 #include "tensorflow/contrib/lite/model.h"
 
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/testing/util.h"
 
 // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
similarity index 95%
rename from tensorflow/contrib/lite/op_resolver.cc
rename to tensorflow/contrib/lite/mutable_op_resolver.cc
index f6e435e..8ee63d2 100644
--- a/tensorflow/contrib/lite/op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -13,8 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000..c319041
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc doesn't support partial specialization in class scope,
+// so these are defined in a namescope.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+  size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+  size_t operator()(const tflite::BuiltinOperator& v) const {
+    return std::hash<int>()(static_cast<int>(v));
+  }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+  size_t operator()(const T& x) const {
+    size_t a = ValueHasher<typename T::first_type>()(x.first);
+    size_t b = ValueHasher<typename T::second_type>()(x.second);
+    return CombineHashes({a, b});
+  }
+};
+}  // namespace op_resolver_hasher
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+//   MutableOpResolver resolver;
+//   resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+//   resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+//   InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+                                   int version) const override;
+  const TfLiteRegistration* FindOp(const char* op, int version) const override;
+  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+                  int min_version = 1, int max_version = 1);
+  void AddCustom(const char* name, TfLiteRegistration* registration,
+                 int min_version = 1, int max_version = 1);
+
+ private:
+  typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+  typedef std::pair<std::string, int> CustomOperatorKey;
+
+  std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+                     op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+      builtins_;
+  std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+                     op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+      custom_ops_;
+};
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
similarity index 98%
rename from tensorflow/contrib/lite/op_resolver_test.cc
rename to tensorflow/contrib/lite/mutable_op_resolver_test.cc
index 10b7e31..db690ea 100644
--- a/tensorflow/contrib/lite/op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -13,7 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
 
 #include <gtest/gtest.h>
 #include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 81dd459..6879440 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -364,6 +364,9 @@
     ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
     uint32_t outputCount, const uint32_t* outputs);
 
+typedef int (*ANeuralNetworksModel_relaxComputationFloat32toFloat16_fn)(
+    ANeuralNetworksModel* model, bool allow);
+
 typedef int (*ANeuralNetworksExecution_create_fn)(
     ANeuralNetworksCompilation* compilation,
     ANeuralNetworksExecution** execution);
@@ -656,6 +659,34 @@
 }
 
 /**
+ * Specifies whether {@link ANEURALNETWORKS_TENSOR_FLOAT32} is allowed to be
+ * calculated with range and/or precision as low as that of the IEEE 754 16-bit
+ * floating-point format. By default, {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * must be calculated using at least the range and precision of the IEEE 754
+ * 32-bit floating-point format.
+ *
+ * @param model The model to be modified.
+ * @param allow 'true' indicates {@link ANEURALNETWORKS_TENSOR_FLOAT32} may be
+ *              calculated with range and/or precision as low as that of the
+ *              IEEE 754 16-bit floating point format. 'false' indicates
+ *              {@link ANEURALNETWORKS_TENSOR_FLOAT32} must be calculated using
+ *              at least the range and precision of the IEEE 754 32-bit floating
+ *              point format.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * Available since API level 28.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ */
+inline int ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+    ANeuralNetworksModel* model, bool allow) {
+  LOAD_FUNCTION(ANeuralNetworksModel_relaxComputationFloat32toFloat16);
+  EXECUTE_FUNCTION_RETURN(model, allow);
+}
+
+/**
  * Create a {@link ANeuralNetworksCompilation} to compile the given model.
  * This only creates the object. Compilation is only performed once
  * {@link ANeuralNetworksCompilation_start} is invoked.
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 2b6caf1..0656884 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -19,8 +19,8 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <unordered_set>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/model.h"
 #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
 
@@ -65,6 +65,14 @@
           __LINE__);                                                    \
   }
 
+#define RETURN_ERROR_IF_TFLITE_FAILED(x)                                       \
+  if (x != kTfLiteOk) {                                                        \
+    logError(                                                                  \
+        "Returning error since TFLite returned failure nnapi_delegate.cc:%d.", \
+        __LINE__);                                                             \
+    return kTfLiteError;                                                       \
+  }
+
 #define RETURN_ERROR_IF_NN_FAILED(x)                                          \
   if (x != ANEURALNETWORKS_NO_ERROR) {                                        \
     logError(                                                                 \
@@ -303,17 +311,21 @@
         };
     auto check_and_add_activation = [&add_scalar_int32](int activation) {
       if (activation > kTfLiteActRelu6) {
-        FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+        logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+        return kTfLiteError;
       }
       add_scalar_int32(activation);
+      return kTfLiteOk;
     };
 
     auto add_add_params = [&add_scalar_int32](void* data) {
       auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
       if (builtin->activation > kTfLiteActRelu6) {
-        FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+        logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+        return kTfLiteError;
       }
       add_scalar_int32(builtin->activation);
+      return kTfLiteOk;
     };
 
     auto add_pooling_params = [&add_scalar_int32,
@@ -324,7 +336,7 @@
       add_scalar_int32(builtin->stride_height);
       add_scalar_int32(builtin->filter_width);
       add_scalar_int32(builtin->filter_height);
-      check_and_add_activation(builtin->activation);
+      return check_and_add_activation(builtin->activation);
     };
 
     auto add_convolution_params = [&add_scalar_int32,
@@ -333,7 +345,7 @@
       add_scalar_int32(builtin->padding);
       add_scalar_int32(builtin->stride_width);
       add_scalar_int32(builtin->stride_height);
-      check_and_add_activation(builtin->activation);
+      return check_and_add_activation(builtin->activation);
     };
 
     auto add_depthwise_conv_params = [&add_scalar_int32,
@@ -343,20 +355,22 @@
       add_scalar_int32(builtin->stride_width);
       add_scalar_int32(builtin->stride_height);
       add_scalar_int32(builtin->depth_multiplier);
-      check_and_add_activation(builtin->activation);
+      return check_and_add_activation(builtin->activation);
     };
 
     auto add_fully_connected_params = [&check_and_add_activation](void* data) {
       auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
-      check_and_add_activation(builtin->activation);
+      return check_and_add_activation(builtin->activation);
     };
 
     auto add_concatenation_params = [&add_scalar_int32](void* data) {
       auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
       add_scalar_int32(builtin->axis);
       if (builtin->activation != kTfLiteActNone) {
-        FATAL("Concatenation does not support fused activation in NNAPI");
+        logError("Concatenation does not support fused activation in NNAPI");
+        return kTfLiteError;
       }
+      return kTfLiteOk;
     };
 
     auto add_softmax_params = [&add_scalar_float32](void* data) {
@@ -437,22 +451,22 @@
     switch (builtin) {
       case tflite::BuiltinOperator_ADD:
         nn_op_type = ANEURALNETWORKS_ADD;
-        add_add_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
         break;
       case tflite::BuiltinOperator_MUL:
         nn_op_type = ANEURALNETWORKS_MUL;
-        add_add_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
         break;
       case tflite::BuiltinOperator_AVERAGE_POOL_2D:
-        add_pooling_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
         break;
       case tflite::BuiltinOperator_MAX_POOL_2D:
-        add_pooling_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
         break;
       case tflite::BuiltinOperator_L2_POOL_2D:
-        add_pooling_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
         break;
       case tflite::BuiltinOperator_CONV_2D: {
@@ -463,7 +477,8 @@
           return kTfLiteError;
         }
       }
-        add_convolution_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(
+            add_convolution_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_CONV_2D;
         break;
       case tflite::BuiltinOperator_RELU:
@@ -482,11 +497,13 @@
         nn_op_type = ANEURALNETWORKS_LOGISTIC;
         break;
       case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
-        add_depthwise_conv_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(
+            add_depthwise_conv_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
         break;
       case tflite::BuiltinOperator_CONCATENATION:
-        add_concatenation_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(
+            add_concatenation_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_CONCATENATION;
         break;
       case tflite::BuiltinOperator_SOFTMAX:
@@ -494,7 +511,8 @@
         nn_op_type = ANEURALNETWORKS_SOFTMAX;
         break;
       case tflite::BuiltinOperator_FULLY_CONNECTED:
-        add_fully_connected_params(node.builtin_data);
+        RETURN_ERROR_IF_TFLITE_FAILED(
+            add_fully_connected_params(node.builtin_data));
         nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
         break;
       case tflite::BuiltinOperator_RESHAPE:
@@ -548,14 +566,14 @@
       case tflite::BuiltinOperator_DIV:
         nnapi_version = 11;  // require NNAPI 1.1
         nn_op_type = ANEURALNETWORKS_DIV;
-        check_and_add_activation(
-            reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation);
+        RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+            reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation));
         break;
       case tflite::BuiltinOperator_SUB:
         nnapi_version = 11;  // require NNAPI 1.1
         nn_op_type = ANEURALNETWORKS_SUB;
-        check_and_add_activation(
-            reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation);
+        RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+            reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation));
         break;
       case tflite::BuiltinOperator_SQUEEZE:
         nnapi_version = 11;  // requires NNAPI 1.1
@@ -658,6 +676,7 @@
       case tflite::BuiltinOperator_UNPACK:
       case tflite::BuiltinOperator_FLOOR_DIV:
       case tflite::BuiltinOperator_REDUCE_ANY:
+      case tflite::BuiltinOperator_SQUARE:
         logError("Op code %d is currently not delegated to NNAPI", builtin);
         return kTfLiteError;
         break;
@@ -668,7 +687,8 @@
     }
 
     if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) {
-      FATAL("Op %d needs NNAPI1.1", builtin);
+      logError("Op %d needs NNAPI1.1", builtin);
+      return kTfLiteError;
     }
 
     // Add the operation.
@@ -716,9 +736,9 @@
                        interpreter->outputs().size());
 
     uint32_t next_id = 0;
-    RETURN_ERROR_IF_NN_FAILED(addTensorOperands(
+    RETURN_ERROR_IF_TFLITE_FAILED(addTensorOperands(
         interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id));
-    RETURN_ERROR_IF_NN_FAILED(
+    RETURN_ERROR_IF_TFLITE_FAILED(
         AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
                         &model_states_outputs_, tensor_id_to_nnapi_id));
 
@@ -742,6 +762,11 @@
         reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
         static_cast<uint32_t>(augmented_outputs.size()),
         reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
+
+    if (GetAndroidSdkVersionCached() >= 28) {
+      CHECK_NN(ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+          nn_model_, interpreter->GetAllowFp16PrecisionForFp32()));
+    }
     CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
   }
   if (!nn_compiled_model_) {
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 2bdb2cc..22359d5 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -16,8 +16,8 @@
 #define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
 
 #include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 
 class ANeuralNetworksModel;
diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
index efde72b..e3536d3 100644
--- a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
@@ -27,7 +27,13 @@
 
 NNAPIAllocation::~NNAPIAllocation() {}
 
-NNAPIDelegate::~NNAPIDelegate() {}
+NNAPIDelegate::~NNAPIDelegate() {
+#define UNUSED_MEMBER(x) (void)(x)
+  UNUSED_MEMBER(nn_model_);
+  UNUSED_MEMBER(nn_compiled_model_);
+  UNUSED_MEMBER(model_status_);
+#undef UNUSED_MEMBER
+}
 
 TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
   return kTfLiteError;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 9d7e3f2..e93134c 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -12,83 +12,11 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// Compatibility shim for moved header location.
 #ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
 #define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
 
-#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/schema/schema_generated.h"
-#include "tensorflow/contrib/lite/util.h"
-
-namespace tflite {
-
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
-  // Finds the op registration for a builtin operator by enum code.
-  virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
-                                           int version) const = 0;
-  // Finds the op registration of a custom operator by op name.
-  virtual const TfLiteRegistration* FindOp(const char* op,
-                                           int version) const = 0;
-  virtual ~OpResolver() {}
-};
-
-// Some versions of gcc doesn't support partial specialization in class scope,
-// so these are defined in a namescope.
-namespace op_resolver_hasher {
-template <typename V>
-struct ValueHasher {
-  size_t operator()(const V& v) const { return std::hash<V>()(v); }
-};
-
-template <>
-struct ValueHasher<tflite::BuiltinOperator> {
-  size_t operator()(const tflite::BuiltinOperator& v) const {
-    return std::hash<int>()(static_cast<int>(v));
-  }
-};
-
-template <typename T>
-struct OperatorKeyHasher {
-  size_t operator()(const T& x) const {
-    size_t a = ValueHasher<typename T::first_type>()(x.first);
-    size_t b = ValueHasher<typename T::second_type>()(x.second);
-    return CombineHashes({a, b});
-  }
-};
-}  // namespace op_resolver_hasher
-
-// An OpResolver that is mutable, also used as the op in gen_op_registration.
-// A typical usage:
-//   MutableOpResolver resolver;
-//   resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-//   resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-//   InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
-  const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
-                                   int version) const override;
-  const TfLiteRegistration* FindOp(const char* op, int version) const override;
-  void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
-                  int min_version = 1, int max_version = 1);
-  void AddCustom(const char* name, TfLiteRegistration* registration,
-                 int min_version = 1, int max_version = 1);
-
- private:
-  typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
-  typedef std::pair<std::string, int> CustomOperatorKey;
-
-  std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
-                     op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
-      builtins_;
-  std::unordered_map<CustomOperatorKey, TfLiteRegistration,
-                     op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
-      custom_ops_;
-};
-
-}  // namespace tflite
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
 
 #endif  // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 6e30251..57e1290 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@
 py_test(
     name = "lite_test",
     srcs = ["lite_test.py"],
-    data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"],
+    data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
     srcs_version = "PY2AND3",
     tags = [
         "no_oss",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1c5516a..1f48a82 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import enum  # pylint: disable=g-bad-import-order
+
 import os as _os
 import platform as _platform
 import subprocess as _subprocess
@@ -30,7 +32,6 @@
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.lazy_loader import LazyLoader
 
-
 # Lazy load since some of the performance benchmark skylark rules
 # break dependencies.
 _toco_python = LazyLoader(
@@ -52,6 +53,31 @@
   _toco_from_proto_bin = "toco_from_protos"
 
 
+class ConverterMode(enum.Enum):
+  """Enum class defining the converters available to generate TFLite models.
+
+  WARNING: Experimental interface, subject to change.
+  """
+  # Convert model using TOCO such that all ops are TensorFlow Lite native ops.
+  #
+  # This is the only supported mode for models that contain operations that
+  # cannot be resolved in TensorFlow.
+  DEFAULT = "DEFAULT"
+
+  # Convert model using TOCO such that only unsupported operations are
+  # represented as TensorFlow ops.
+  # WARNING: Experimental interface, subject to change.
+  TOCO_EXTENDED = "TOCO_EXTENDED"
+
+  # Convert model using TOCO such that all operations are represented as
+  # TensorFlow ops.
+  # WARNING: Experimental interface, subject to change.
+  TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+
+  def __str__(self):
+    return self.value
+
+
 def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
   """Convert `input_data_str` according to model and toco parameters.
 
@@ -128,7 +154,8 @@
                               change_concat_input_ranges=False,
                               post_training_quantize=False,
                               dump_graphviz_dir=None,
-                              dump_graphviz_video=False):
+                              dump_graphviz_video=False,
+                              converter_mode=ConverterMode.DEFAULT):
   """Builds protocol buffers describing a conversion of a model using TOCO.
 
   Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -183,6 +210,8 @@
       output file. (default None)
     dump_graphviz_video: Boolean indicating whether to dump the graph after
       every graph transformation. (default False)
+    converter_mode: Experimental flag, subject to change. ConverterMode
+      indicating which converter to use. (default ConverterMode.DEFAULT)
 
   Returns:
     model_flags, toco_flags: two protocol buffers describing the conversion
@@ -211,6 +240,11 @@
   if dump_graphviz_dir:
     toco.dump_graphviz_dir = dump_graphviz_dir
   toco.dump_graphviz_include_video = dump_graphviz_video
+  if converter_mode == ConverterMode.TOCO_EXTENDED:
+    toco.allow_eager_ops = True
+  elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
+    toco.allow_eager_ops = True
+    toco.force_eager_ops = True
 
   model = _model_flags_pb2.ModelFlags()
   model.change_concat_input_ranges = change_concat_input_ranges
@@ -301,9 +335,8 @@
   Raises:
     Defined in `build_toco_convert_protos`.
   """
-  model_flags, toco_flags = build_toco_convert_protos(input_tensors,
-                                                      output_tensors,
-                                                      *args, **kwargs)
+  model_flags, toco_flags = build_toco_convert_protos(
+      input_tensors, output_tensors, *args, **kwargs)
   data = toco_convert_protos(model_flags.SerializeToString(),
                              toco_flags.SerializeToString(),
                              input_data.SerializeToString())
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index 59f537b..40a8b5f 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -188,7 +188,7 @@
       return output
     output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # check if identities have been put into the graph (2 input, 1 output,
       # and 1 final output).
       self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
@@ -215,7 +215,7 @@
     output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
                                 name="ModelOutput")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
@@ -242,7 +242,7 @@
     output = array_ops.identity(
         math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # make sure one identity for each input (2) and output (2) => 2 + 2
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
@@ -279,7 +279,7 @@
                          aggregate=op_hint.OpHint.AGGREGATE_STACK)
     res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
     custom.add_outputs([res])
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(self._get_input_index(a), 0)
       self.assertEqual(self._get_sort_index(a), 0)
       self.assertEqual(self._get_input_index(b), 1)
@@ -294,7 +294,7 @@
     b = custom.add_input(b)  # should auto assign 0
     a = custom.add_input(a, index_override=1)
     c = custom.add_input(c)  # should auto assign 2
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(self._get_input_index(a), 1)
       self.assertEqual(self._get_input_index(b), 0)
       self.assertEqual(self._get_input_index(c), 2)
@@ -320,7 +320,7 @@
 
     curr = array_ops.stack([c0, c1])
     output = array_ops.identity(curr, name="FINAL_OUTPUT")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
           graph_def=sess.graph_def)
       self.assertCountEqual(
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2de97fe..2be2445 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,6 +40,7 @@
 from google.protobuf.message import DecodeError
 from tensorflow.contrib.lite.python import lite_constants as constants
 from tensorflow.contrib.lite.python.convert import build_toco_convert_protos  # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import ConverterMode
 from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
 from tensorflow.contrib.lite.python.convert import toco_convert  # pylint: disable=unused-import
 from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
@@ -58,6 +59,7 @@
 from tensorflow.python.framework import ops as _ops
 from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
 from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
 from tensorflow.python.saved_model import signature_constants as _signature_constants
 from tensorflow.python.saved_model import tag_constants as _tag_constants
 
@@ -112,6 +114,8 @@
       output file. (default None)
     dump_graphviz_video: Boolean indicating whether to dump the graph after
       every graph transformation. (default False)
+    converter_mode: Experimental flag, subject to change. ConverterMode
+      indicating which converter to use. (default ConverterMode.DEFAULT)
 
   Example usage:
 
@@ -178,6 +182,7 @@
     self.post_training_quantize = False
     self.dump_graphviz_dir = None
     self.dump_graphviz_video = False
+    self.converter_mode = ConverterMode.DEFAULT
 
     # Attributes are used by models that cannot be loaded into TensorFlow.
     if not self._has_valid_tensors():
@@ -225,8 +230,10 @@
       TocoConverter class.
 
     Raises:
-      ValueError:
+      IOError:
+        File not found.
         Unable to parse input file.
+      ValueError:
         The graph is not frozen.
         input_arrays or output_arrays contains an invalid tensor name.
         input_shapes is not correctly defined when required
@@ -234,10 +241,13 @@
     with _ops.Graph().as_default():
       with _session.Session() as sess:
         # Read GraphDef from file.
-        graph_def = _graph_pb2.GraphDef()
-        with open(graph_def_file, "rb") as f:
+        if not _file_io.file_exists(graph_def_file):
+          raise IOError("File '{0}' does not exist.".format(graph_def_file))
+        with _file_io.FileIO(graph_def_file, "rb") as f:
           file_content = f.read()
+
         try:
+          graph_def = _graph_pb2.GraphDef()
           graph_def.ParseFromString(file_content)
         except (_text_format.ParseError, DecodeError):
           try:
@@ -248,9 +258,10 @@
                 file_content = file_content.decode("utf-8")
               else:
                 file_content = file_content.encode("utf-8")
+            graph_def = _graph_pb2.GraphDef()
             _text_format.Merge(file_content, graph_def)
           except (_text_format.ParseError, DecodeError):
-            raise ValueError(
+            raise IOError(
                 "Unable to parse input file '{}'.".format(graph_def_file))
 
         # Handles models with custom TFLite ops that cannot be resolved in
@@ -382,6 +393,7 @@
       ValueError:
         Input shape is not specified.
         None value for dimension in input_tensor.
+        ConverterMode option is unsupported for the model.
     """
     # Checks dimensions in input tensor.
     if self._has_valid_tensors():
@@ -432,12 +444,18 @@
 
     # Converts model.
     if self._has_valid_tensors():
+      converter_kwargs["converter_mode"] = self.converter_mode
       result = _toco_convert_impl(
           input_data=self._graph_def,
           input_tensors=self._input_tensors,
           output_tensors=self._output_tensors,
           **converter_kwargs)
     else:
+      # Graphs without valid tensors cannot be loaded into tf.Session since they
+      # contain TFLite operation(s) that cannot be resolved in TensorFlow.
+      if self.converter_mode != ConverterMode.DEFAULT:
+        raise ValueError("This model can only be converted with the default "
+                         "converter.")
       result = _toco_convert_graph_def(
           input_data=self._graph_def,
           input_arrays_with_shape=self._input_arrays_with_shape,
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 1c94ba6..f112ed5 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -402,6 +402,28 @@
     # Ensure that the quantized weights tflite model is smaller.
     self.assertTrue(len(quantized_tflite) < len(float_tflite))
 
+  def testExtendedMode(self):
+    in_tensor = array_ops.placeholder(
+        shape=[1, 16, 16, 3], dtype=dtypes.float32)
+    out_tensor = in_tensor + in_tensor
+    sess = session.Session()
+
+    # Convert model and ensure model is not None.
+    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+    converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+    tflite_model = converter.convert()
+    self.assertTrue(tflite_model)
+
+    # Ensures the model contains TensorFlow ops.
+    # TODO(nupurgarg): Check values once there is a Python delegate interface.
+    interpreter = Interpreter(model_content=tflite_model)
+    with self.assertRaises(RuntimeError) as error:
+      interpreter.allocate_tensors()
+    self.assertIn(
+        'Regular TensorFlow ops are not supported by this interpreter. Make '
+        'sure you invoke the Eager delegate before inference.',
+        str(error.exception))
+
 
 class FromFrozenGraphFile(test_util.TensorFlowTestCase):
 
@@ -521,14 +543,21 @@
     self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
     self.assertEqual((0., 0.), output_details[0]['quantization'])
 
-  def testInvalidFile(self):
+  def testInvalidFileNotFound(self):
+    with self.assertRaises(IOError) as error:
+      lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+                                           ['add'])
+    self.assertEqual('File \'invalid_file\' does not exist.',
+                     str(error.exception))
+
+  def testInvalidFileBadData(self):
     graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
     with gfile.Open(graph_def_file, 'wb') as temp_file:
       temp_file.write('bad data')
       temp_file.flush()
 
     # Attempts to convert the invalid model.
-    with self.assertRaises(ValueError) as error:
+    with self.assertRaises(IOError) as error:
       lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                            ['add'])
     self.assertEqual(
@@ -539,7 +568,7 @@
   def _initObjectDetectionArgs(self):
     # Initializes the arguments required for the object detection model.
     self._graph_def_file = resource_loader.get_path_to_datafile(
-        'testdata/tflite_graph.pbtxt')
+        'testdata/tflite_graph.pb')
     self._input_arrays = ['normalized_input_image_tensor']
     self._output_arrays = [
         'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
@@ -586,7 +615,7 @@
                      output_details[3]['name'])
     self.assertTrue(([1] == output_details[3]['shape']).all())
 
-  def testTFLiteGraphDefInvalid(self):
+  def testTFLiteGraphDefMissingShape(self):
     # Tests invalid cases for the model that cannot be loaded in TensorFlow.
     self._initObjectDetectionArgs()
 
@@ -597,6 +626,10 @@
     self.assertEqual('input_shapes must be defined for this model.',
                      str(error.exception))
 
+  def testTFLiteGraphDefInvalidShape(self):
+    # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+    self._initObjectDetectionArgs()
+
     # `input_shapes` does not contain the names in `input_arrays`.
     with self.assertRaises(ValueError) as error:
       lite.TocoConverter.from_frozen_graph(
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index cc08ed3..c0ff7f3 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -140,8 +140,11 @@
   if flags.change_concat_input_ranges:
     converter.change_concat_input_ranges = (
         flags.change_concat_input_ranges == "TRUE")
+
   if flags.allow_custom_ops:
     converter.allow_custom_ops = flags.allow_custom_ops
+  if flags.converter_mode:
+    converter.converter_mode = flags.converter_mode
 
   if flags.post_training_quantize:
     converter.post_training_quantize = flags.post_training_quantize
@@ -363,6 +366,8 @@
       help=("Boolean to change behavior of min/max ranges for inputs and "
             "outputs of the concat operator for quantized models. Changes the "
             "ranges of concat operator overlap when true. (default False)"))
+
+  # Permitted ops flags.
   parser.add_argument(
       "--allow_custom_ops",
       action="store_true",
@@ -371,6 +376,12 @@
             "created for any op that is unknown. The developer will need to "
             "provide these to the TensorFlow Lite runtime with a custom "
             "resolver. (default False)"))
+  parser.add_argument(
+      "--converter_mode",
+      type=lite.ConverterMode,
+      choices=list(lite.ConverterMode),
+      help=("Experimental flag, subject to change. ConverterMode indicating "
+            "which converter to use. (default ConverterMode.DEFAULT)"))
 
   # Logging flags.
   parser.add_argument(
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 28a7e50..55bf2c4 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -56,6 +56,20 @@
     srcs = ["schema.fbs"],
 )
 
+# Generic schema for on-device inference (reflection support makes it bigger).
+flatbuffer_cc_library(
+    name = "schema_fbs_with_reflection",
+    srcs = ["schema.fbs"],
+    flatc_args = [
+        "--reflect-types",
+        "--reflect-names",
+        "--no-union-value-namespacing",
+        "--gen-object-api",
+    ],
+    gen_reflections = True,
+    out_prefix = "reflection/",
+)
+
 # Schema test to make sure we don't introduce backward incompatible changes
 # to schemas.
 cc_test(
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index cf66403..f0db22d 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -173,6 +173,7 @@
   REDUCE_MIN = 89,
   FLOOR_DIV = 90,
   REDUCE_ANY = 91,
+  SQUARE = 92,
 }
 
 // Options for the builtin operators.
@@ -242,6 +243,7 @@
   LogicalNotOptions,
   UnpackOptions,
   FloorDivOptions,
+  SquareOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -274,11 +276,15 @@
 }
 
 table DepthwiseConv2DOptions {
+  // Parameters for DepthwiseConv version 1 or above.
   padding:Padding;
   stride_w:int;
   stride_h:int;
   depth_multiplier:int;
   fused_activation_function:ActivationFunctionType;
+  // Parameters for DepthwiseConv version 2 or above.
+  dilation_w_factor:int = 1;
+  dilation_h_factor:int = 1;
 }
 
 table ConcatEmbeddingsOptions {
@@ -579,6 +585,9 @@
 table FloorDivOptions {
 }
 
+table SquareOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 6d9630d..8c086a5 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -226,6 +226,9 @@
 struct FloorDivOptions;
 struct FloorDivOptionsT;
 
+struct SquareOptions;
+struct SquareOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -383,11 +386,12 @@
   BuiltinOperator_REDUCE_MIN = 89,
   BuiltinOperator_FLOOR_DIV = 90,
   BuiltinOperator_REDUCE_ANY = 91,
+  BuiltinOperator_SQUARE = 92,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY
+  BuiltinOperator_MAX = BuiltinOperator_SQUARE
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[92] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -479,7 +483,8 @@
     BuiltinOperator_UNPACK,
     BuiltinOperator_REDUCE_MIN,
     BuiltinOperator_FLOOR_DIV,
-    BuiltinOperator_REDUCE_ANY
+    BuiltinOperator_REDUCE_ANY,
+    BuiltinOperator_SQUARE
   };
   return values;
 }
@@ -578,6 +583,7 @@
     "REDUCE_MIN",
     "FLOOR_DIV",
     "REDUCE_ANY",
+    "SQUARE",
     nullptr
   };
   return names;
@@ -655,11 +661,12 @@
   BuiltinOptions_LogicalNotOptions = 63,
   BuiltinOptions_UnpackOptions = 64,
   BuiltinOptions_FloorDivOptions = 65,
+  BuiltinOptions_SquareOptions = 66,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions
+  BuiltinOptions_MAX = BuiltinOptions_SquareOptions
 };
 
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[67] {
   static BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -726,7 +733,8 @@
     BuiltinOptions_LogicalAndOptions,
     BuiltinOptions_LogicalNotOptions,
     BuiltinOptions_UnpackOptions,
-    BuiltinOptions_FloorDivOptions
+    BuiltinOptions_FloorDivOptions,
+    BuiltinOptions_SquareOptions
   };
   return values;
 }
@@ -799,6 +807,7 @@
     "LogicalNotOptions",
     "UnpackOptions",
     "FloorDivOptions",
+    "SquareOptions",
     nullptr
   };
   return names;
@@ -1073,6 +1082,10 @@
   static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
 };
 
+template<> struct BuiltinOptionsTraits<SquareOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -1624,6 +1637,14 @@
     return type == BuiltinOptions_FloorDivOptions ?
       reinterpret_cast<const FloorDivOptionsT *>(value) : nullptr;
   }
+  SquareOptionsT *AsSquareOptions() {
+    return type == BuiltinOptions_SquareOptions ?
+      reinterpret_cast<SquareOptionsT *>(value) : nullptr;
+  }
+  const SquareOptionsT *AsSquareOptions() const {
+    return type == BuiltinOptions_SquareOptions ?
+      reinterpret_cast<const SquareOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2318,12 +2339,16 @@
   int32_t stride_h;
   int32_t depth_multiplier;
   ActivationFunctionType fused_activation_function;
+  int32_t dilation_w_factor;
+  int32_t dilation_h_factor;
   DepthwiseConv2DOptionsT()
       : padding(Padding_SAME),
         stride_w(0),
         stride_h(0),
         depth_multiplier(0),
-        fused_activation_function(ActivationFunctionType_NONE) {
+        fused_activation_function(ActivationFunctionType_NONE),
+        dilation_w_factor(1),
+        dilation_h_factor(1) {
   }
 };
 
@@ -2334,7 +2359,9 @@
     VT_STRIDE_W = 6,
     VT_STRIDE_H = 8,
     VT_DEPTH_MULTIPLIER = 10,
-    VT_FUSED_ACTIVATION_FUNCTION = 12
+    VT_FUSED_ACTIVATION_FUNCTION = 12,
+    VT_DILATION_W_FACTOR = 14,
+    VT_DILATION_H_FACTOR = 16
   };
   Padding padding() const {
     return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0));
@@ -2351,6 +2378,12 @@
   ActivationFunctionType fused_activation_function() const {
     return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
+  int32_t dilation_w_factor() const {
+    return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
+  }
+  int32_t dilation_h_factor() const {
+    return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
            VerifyField<int8_t>(verifier, VT_PADDING) &&
@@ -2358,6 +2391,8 @@
            VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
            VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) &&
            VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+           VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
+           VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
            verifier.EndTable();
   }
   DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2383,6 +2418,12 @@
   void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
     fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
   }
+  void add_dilation_w_factor(int32_t dilation_w_factor) {
+    fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+  }
+  void add_dilation_h_factor(int32_t dilation_h_factor) {
+    fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+  }
   explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -2401,8 +2442,12 @@
     int32_t stride_w = 0,
     int32_t stride_h = 0,
     int32_t depth_multiplier = 0,
-    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+    ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+    int32_t dilation_w_factor = 1,
+    int32_t dilation_h_factor = 1) {
   DepthwiseConv2DOptionsBuilder builder_(_fbb);
+  builder_.add_dilation_h_factor(dilation_h_factor);
+  builder_.add_dilation_w_factor(dilation_w_factor);
   builder_.add_depth_multiplier(depth_multiplier);
   builder_.add_stride_h(stride_h);
   builder_.add_stride_w(stride_w);
@@ -5803,6 +5848,46 @@
 
 flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct SquareOptionsT : public flatbuffers::NativeTable {
+  typedef SquareOptions TableType;
+  SquareOptionsT() {
+  }
+};
+
+struct SquareOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef SquareOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  SquareOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<SquareOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SquareOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit SquareOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  SquareOptionsBuilder &operator=(const SquareOptionsBuilder &);
+  flatbuffers::Offset<SquareOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<SquareOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  SquareOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   BuiltinOperator builtin_code;
@@ -6131,6 +6216,9 @@
   const FloorDivOptions *builtin_options_as_FloorDivOptions() const {
     return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast<const FloorDivOptions *>(builtin_options()) : nullptr;
   }
+  const SquareOptions *builtin_options_as_SquareOptions() const {
+    return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast<const SquareOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -6422,6 +6510,10 @@
   return builtin_options_as_FloorDivOptions();
 }
 
+template<> inline const SquareOptions *Operator::builtin_options_as<SquareOptions>() const {
+  return builtin_options_as_SquareOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -6996,6 +7088,8 @@
   { auto _e = stride_h(); _o->stride_h = _e; };
   { auto _e = depth_multiplier(); _o->depth_multiplier = _e; };
   { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+  { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; };
+  { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; };
 }
 
 inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7011,13 +7105,17 @@
   auto _stride_h = _o->stride_h;
   auto _depth_multiplier = _o->depth_multiplier;
   auto _fused_activation_function = _o->fused_activation_function;
+  auto _dilation_w_factor = _o->dilation_w_factor;
+  auto _dilation_h_factor = _o->dilation_h_factor;
   return tflite::CreateDepthwiseConv2DOptions(
       _fbb,
       _padding,
       _stride_w,
       _stride_h,
       _depth_multiplier,
-      _fused_activation_function);
+      _fused_activation_function,
+      _dilation_w_factor,
+      _dilation_h_factor);
 }
 
 inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -8661,6 +8759,29 @@
       _fbb);
 }
 
+inline SquareOptionsT *SquareOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new SquareOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void SquareOptions::UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<SquareOptions> SquareOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateSquareOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SquareOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateSquareOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -9110,6 +9231,10 @@
       auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_SquareOptions: {
+      auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return false;
   }
 }
@@ -9388,6 +9513,10 @@
       auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_SquareOptions: {
+      auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -9654,6 +9783,10 @@
       auto ptr = reinterpret_cast<const FloorDivOptionsT *>(value);
       return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_SquareOptions: {
+      auto ptr = reinterpret_cast<const SquareOptionsT *>(value);
+      return CreateSquareOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -9920,6 +10053,10 @@
       value = new FloorDivOptionsT(*reinterpret_cast<FloorDivOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_SquareOptions: {
+      value = new SquareOptionsT(*reinterpret_cast<SquareOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -10252,6 +10389,11 @@
       delete ptr;
       break;
     }
+    case BuiltinOptions_SquareOptions: {
+      auto ptr = reinterpret_cast<SquareOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index f738315..45d0d87 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -17,7 +17,7 @@
 
 #include <list>
 #include <memory>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc
similarity index 72%
rename from tensorflow/contrib/lite/error_reporter.cc
rename to tensorflow/contrib/lite/stderr_reporter.cc
index 646913c..e29a634 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/stderr_reporter.cc
@@ -12,7 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
 #include <cstdarg>
 #include <cstdio>
 
@@ -22,26 +22,6 @@
 
 namespace tflite {
 
-ErrorReporter::~ErrorReporter() {}
-
-int ErrorReporter::Report(const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  int code = Report(format, args);
-  va_end(args);
-  return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
-  va_list args;
-  va_start(args, format);
-  int code = Report(format, args);
-  va_end(args);
-  return code;
-}
-
 int StderrReporter::Report(const char* format, va_list args) {
 #ifdef __ANDROID__
   // On Android stderr is not captured for applications, only for code run from
diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h
new file mode 100644
index 0000000..c6f4ffb
--- /dev/null
+++ b/tensorflow/contrib/lite/stderr_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+// An error reporter that simplify writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+  int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a316a40..b991e99 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -17,7 +17,7 @@
 
 #include <string.h>
 #include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 08a20ce..0a83303 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -42,7 +42,7 @@
 
 #include <vector>
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/string_tflite.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
index d53fec7..a583a91 100644
--- a/tensorflow/contrib/lite/string_util_test.cc
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -15,7 +15,7 @@
 #include "tensorflow/contrib/lite/string_util.h"
 
 #include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/testing/util.h"
 
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 89912fd..a4736bf 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -7,7 +7,7 @@
 load(
     "//tensorflow/contrib/lite:build_def.bzl",
     "gen_zip_test",
-    "generated_test_models",
+    "generated_test_models_all",
 )
 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
 load(
@@ -29,6 +29,7 @@
             "--unzip_binary_path=/usr/bin/unzip",
         ],
     }),
+    conversion_mode = conversion_mode,
     data = [
         ":zip_%s" % test_name,
     ],
@@ -36,7 +37,7 @@
     tags = [
         "gen_zip_test",
         "no_oss",
-        "tflite_not_portable",
+        "tflite_not_portable_intentional",
     ],
     test_name = test_name,
     deps = [
@@ -59,7 +60,7 @@
             "//tensorflow/core:android_tensorflow_test_lib",
         ],
     }),
-) for test_name in generated_test_models()]
+) for conversion_mode, test_name in generated_test_models_all()]
 
 test_suite(
     name = "generated_zip_tests",
@@ -214,6 +215,7 @@
     deps = [
         "//tensorflow/contrib/lite:framework",
         "//tensorflow/contrib/lite:string",
+        "//tensorflow/contrib/lite/core/api",
     ],
 )
 
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 57134cc..3754b58 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -80,7 +80,10 @@
     "--save_graphdefs",
     action="store_true",
     help="Include intermediate graphdefs in the output zip files.")
-
+parser.add_argument(
+    "--run_with_extended",
+    action="store_true",
+    help="Whether the TFLite Extended converter is being used.")
 
 RANDOM_SEED = 342
 TEST_INPUT_DEPTH = 3
@@ -320,10 +323,11 @@
     output tflite model, log_txt from conversion
     or None, log_txt if it did not convert properly.
   """
+  input_arrays = [x[0] for x in input_tensors]
   data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
   opts = toco_options(
       data_types=data_types,
-      input_arrays=[x[0] for x in input_tensors],
+      input_arrays=input_arrays,
       shapes=[x[1] for x in input_tensors],
       output_arrays=output_tensors,
       extra_toco_options=extra_toco_options)
@@ -335,6 +339,11 @@
     graphdef_file.flush()
 
     # TODO(aselle): Switch this to subprocess at some point.
+    if "pb2lite" in bin_path and FLAGS.run_with_extended:
+      opts = ("--input_arrays={0} --output_arrays={1}".format(
+          ",".join(input_arrays), ",".join(output_tensors)))
+    elif FLAGS.run_with_extended:
+      opts += " --allow_eager_ops --force_eager_ops"
     cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
            (bin_path, graphdef_file.name, output_file.name, opts,
             stdout_file.name))
@@ -1425,6 +1434,7 @@
           "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
           "filter_size": [[1, 1], [1, 2], [3, 3]],
           "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+          "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
           "channel_multiplier": [1, 2],
           "rate": [[1, 1]],
           "padding": ["SAME", "VALID"],
@@ -1435,6 +1445,7 @@
           "input_shape": [[1, 3, 4, 3]],
           "filter_size": [[1, 1]],
           "strides": [[1, 1, 2, 1]],  # TF needs [1, x, x, 1]
+          "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]],
           "channel_multiplier": [2],
           "rate": [[2, 2]],  #  Only [1, 1] is supported
           "padding": ["SAME"],
@@ -1502,7 +1513,7 @@
         dtype=tf.float32, name="input", shape=parameters["input_shape"])
     out = tf.split(
         input_tensor, parameters["num_or_size_splits"], parameters["axis"])
-    return [input_tensor], out
+    return [input_tensor], [out[0]]
 
   def build_inputs(parameters, sess, inputs, outputs):
     values = [create_tensor_data(np.float32, parameters["input_shape"])]
@@ -1679,6 +1690,7 @@
 
   # TODO(nupurgarg): Add test for tf.uint8.
   test_parameters = [
+      # 4D:
       {
           "dtype": [tf.int32, tf.int64, tf.float32],
           "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1686,13 +1698,20 @@
                                                           [0, 0], [2, 3]]],
           "constant_paddings": [True, False],
       },
-      # Non-4D use case.
+      # 2D:
       {
           "dtype": [tf.int32, tf.int64, tf.float32],
-          "input_shape": [[1, 2], [0, 1, 2]],
+          "input_shape": [[1, 2]],
           "paddings": [[[0, 1], [2, 3]]],
           "constant_paddings": [True, False],
       },
+      # 1D:
+      {
+          "dtype": [tf.int32],
+          "input_shape": [[1]],
+          "paddings": [[[1, 2]]],
+          "constant_paddings": [False],
+      },
   ]
 
   def build_graph(parameters):
@@ -1730,6 +1749,7 @@
 
   # TODO(nupurgarg): Add test for tf.uint8.
   test_parameters = [
+      # 4D:
       {
           "dtype": [tf.int32, tf.int64, tf.float32],
           "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1738,14 +1758,22 @@
           "constant_paddings": [True, False],
           "constant_values": [0, 2],
       },
-      # Non-4D use case.
+      # 2D:
       {
           "dtype": [tf.int32, tf.int64, tf.float32],
-          "input_shape": [[1, 2], [0, 1, 2]],
+          "input_shape": [[1, 2]],
           "paddings": [[[0, 1], [2, 3]]],
           "constant_paddings": [True, False],
           "constant_values": [0, 2],
       },
+      # 1D:
+      {
+          "dtype": [tf.int32],
+          "input_shape": [[1]],
+          "paddings": [[[0, 1]]],
+          "constant_paddings": [False],
+          "constant_values": [0, 2],
+      },
   ]
 
   def build_graph(parameters):
@@ -2493,10 +2521,12 @@
         shape=parameters["input_shape"])
     if parameters["input_k"] is not None:
       k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+      inputs = [input_value, k]
     else:
       k = tf.constant(3, name="k")
+      inputs = [input_value]
     out = tf.nn.top_k(input_value, k)
-    return [input_value, k], [out[1]]
+    return inputs, [out[1]]
 
   def build_inputs(parameters, sess, inputs, outputs):
     input_value = create_tensor_data(parameters["input_dtype"],
@@ -2854,6 +2884,11 @@
   return _make_elementwise_tests(tf.rsqrt)(zip_path)
 
 
+def make_square_tests(zip_path):
+  """Make a set of tests to do square."""
+  return _make_elementwise_tests(tf.square)(zip_path)
+
+
 def make_where_tests(zip_path):
   """Make a set of tests to do where."""
 
@@ -3191,7 +3226,7 @@
     input_tensor = tf.placeholder(
         dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
     outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
-    return [input_tensor], outs
+    return [input_tensor], [outs[0]]
 
   def build_inputs(parameters, sess, inputs, outputs):
     input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
@@ -3269,7 +3304,11 @@
 
   out = FLAGS.zip_to_output
   bin_path = FLAGS.toco
-  test_function = ("make_%s_tests" % out.replace(".zip", ""))
+  # Some zip filenames contain a postfix identifying the conversion mode. The
+  # list of valid conversion modes is defined in
+  # generated_test_conversion_modes() in build_def.bzl.
+  test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
+      "pb2lite", "").replace("toco-extended", "").rstrip("_")))
   if test_function not in globals():
     raise RuntimeError("Can't find a test function to create %r. Tried %r" %
                        (out, test_function))
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 37c7ae0..349aa5a 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -58,12 +58,6 @@
 // Key is a substring of the test name and value is a bug number.
 // TODO(ahentz): make sure we clean this list up frequently.
 std::map<string, string> kBrokenTests = {
-    // Pad and PadV2 only supports 4D tensors.
-    {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
-     "70527055"},
-    {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
-     "70527055"},
-
     // L2Norm only supports tensors with 4D or fewer.
     {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
 
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 8aa6391..72e4f68 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -17,8 +17,8 @@
 
 #include <cstdio>
 
-#include "tensorflow/contrib/lite/error_reporter.h"
-#include "tensorflow/contrib/lite/string.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/string_tflite.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp
index 109f7d7..6036413 100644
--- a/tensorflow/contrib/lite/tflite_static.bp
+++ b/tensorflow/contrib/lite/tflite_static.bp
@@ -20,17 +20,20 @@
     srcs: [
         "allocation.cc",
         "arena_planner.cc",
-        "context.c",
+        "c/c_api_internal.c",
+        "core/api/error_reporter.cc",
+        "core/api/flatbuffer_conversions.cc",
+        "core/api/op_resolver.cc",
         "delegates/nnapi/nnapi_delegate.cc",
-        "error_reporter.cc",
 	"graph_info.cc",
         "interpreter.cc",
         "mmap_allocation.cc",
         "model.cc",
+        "mutable_op_resolver.cc",
         "nnapi_delegate.cc",
         "optional_debug_tools.cc",
-        "op_resolver.cc",
         "simple_memory_arena.cc",
+        "stderr_reporter.cc",
         "string_util.cc",
         "util.cc",
 	"kernels/elementwise.cc",
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index a75553d..bea90f1 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -372,6 +372,7 @@
         ":toco_graphviz_dump_options",
         ":toco_port",
         ":types_proto_cc",
+        "//tensorflow/contrib/lite/kernels/internal:types",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
         "@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 84f71dc..f14dbc2 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -247,6 +247,10 @@
   Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
   Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
   Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
+  // WARNING: Experimental interface, subject to change
+  Arg<bool> allow_eager_ops = Arg<bool>(false);
+  // WARNING: Experimental interface, subject to change
+  Arg<bool> force_eager_ops = Arg<bool>(false);
 };
 
 }  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
index 262e13a..335debd 100644
--- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
+++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
@@ -1 +1 @@
-<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 
0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 
-0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 
0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 
0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 
0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 
5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 
0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 
-0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 
0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 
0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 
-0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 
-0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 
-0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 
-0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 
-0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 
-2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 
7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 
-0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 
0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 
-1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 
-0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 
0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 
-0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 
0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 
-0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 
0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 
0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 
0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 
-0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" 
stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 
0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 
-0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 
-0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 
0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 
0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 
0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 
0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 
0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 
2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" 
d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 
0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 
-0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 
0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 
-0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 
-0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 
-0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 
-0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 
-1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 
0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 
1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 
0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.20152 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 
-0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 
0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 
-0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 
-0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 
0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" 
fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 
2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 
1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 
2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 
0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 
-6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 
-0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m75.62294 283.52823l0 17.950958l100.62993 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62295 283.52823l0 17.950928l100.62992 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.25287 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 17.950958l-100.62991 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" 
d="m276.85654 283.52823l0 17.950928l-100.62991 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.22662 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" 
fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#d9ead3" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m146.9475 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 
0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 
-0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 
-0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 
0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 
3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/></g></svg>
\ No newline at end of file
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 
0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 
-0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 
0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 
0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 
0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 
5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 
0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 
-0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 
0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 
0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 
-0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 
-0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 
-0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 
-0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 
-0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 
-2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 
7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 
-0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 
0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 
-1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 
-0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 
0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 
-0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 
0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 
-0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 
0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 
0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 
0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 
-0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" 
stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 
0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 
-0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 
-0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 
0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 
0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 
0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 
0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 
0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 
2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" 
d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 
0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 
-0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 
0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 
-0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 
-0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 
-0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 
-0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 
-1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 
0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 
1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 
0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.20152 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 
-0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 
0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 
-0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 
-0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 
0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" 
fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 
2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 
1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 
2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 
0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 
-6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 
-0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m75.62294 283.52823l0 17.950958l100.62993 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62295 283.52823l0 17.950928l100.62992 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.25287 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 17.950958l-100.62991 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" 
d="m276.85654 283.52823l0 17.950928l-100.62991 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.22662 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" 
fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#d9ead3" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m146.9475 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 
0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 
-0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 
-0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 
0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 
3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m99.50131 100.0l0 76.0l54.992126 0l0 76.0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m99.50131 100.0l0 76.0l54.992126 0l0 72.57292" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m154.49344 248.5729l-1.124588 -1.1245728l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/></g></svg>
\ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 502de88..3114fa9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -63,6 +63,25 @@
   return true;
 }
 
+bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
+  auto& input = model->GetArray(op->inputs[0]);
+  if (input.minmax) {
+    const auto* minmax = input.minmax.get();
+    if (minmax) {
+      return false;
+    }
+  }
+  auto& output = model->GetArray(op->outputs[0]);
+  if (output.minmax) {
+    const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+    if (minmax) {
+      input.GetOrCreateMinMax() = *minmax;
+      return true;
+    }
+  }
+  return false;
+}
+
 bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
   // Do not early return if the output already has min/max:
   // we may still need to adjust the inputs min/max.
@@ -366,6 +385,16 @@
       changed = HardcodeMinMaxForL2Normalization(model, op);
       break;
 
+    case OperatorType::kRelu:
+      // For any normalization other than batch norm, the quantizations ranges
+      // before and after relu are expected to be known. Having a quantization
+      // op before relu would reduce the number of bits of precision for the
+      // activation in half. So we deduce the range before relu from that after
+      // the relu. This would eliminate the need for two fake quantization nodes
+      // and would not reduce the bits of precision available for activation.
+      changed = HardcodeInputMinMaxFromOutput(model, op);
+      break;
+
     case OperatorType::kConcatenation:
       changed = HardcodeMinMaxForConcatenation(model, op);
       break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index c25be07..f103bb9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1314,12 +1314,16 @@
 
   // Compute output shape
   for (int axis = 0; axis < num_input_axes; ++axis) {
+    const auto strided_slice_params =
+        tflite::strided_slice::BuildStridedSliceParams(
+            op->begin_mask, op->end_mask, op->shrink_axis_mask,
+            op->start_indices, op->stop_indices, op->strides);
     int start_index = tflite::strided_slice::StartForAxis(
-        op->begin_mask, op->start_indices, op->strides,
-        input_array.shape().dims().data(), axis);
+        strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
     int stop_index = tflite::strided_slice::StopForAxis(
-        op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
-        input_array.shape().dims().data(), axis, start_index);
+        strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+        start_index);
+
     int dim_size =
         ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
 
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4f..8853ed8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@
   Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
   std::vector<int> src_coord(num_input_axes);
   std::vector<int> stop_for_axis(num_input_axes);
+  const auto strided_slice_params =
+      tflite::strided_slice::BuildStridedSliceParams(
+          op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+          op.stop_indices, op.strides);
+
   for (int axis = 0; axis < num_input_axes; axis++) {
-    int start = tflite::strided_slice::StartForAxis(
-        op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
-        axis);
-    src_coord[axis] = start;
+    int start_index = tflite::strided_slice::StartForAxis(
+        strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+    src_coord[axis] = start_index;
     stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
-        op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
-        input_shape.dims().data(), axis, start);
+        strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+        start_index);
   }
 
   // In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@
       if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
         // Reset axis and set carry
         src_coord[axis] = tflite::strided_slice::StartForAxis(
-            op.begin_mask, op.start_indices, op.strides,
-            input_shape.dims().data(), axis);
+            strided_slice_params, ToRuntimeShape(input_shape), axis);
         carry = true;
       } else {
         carry = false;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index cb6da21..9bc23c4 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2061,8 +2061,14 @@
   }
 
   Model* model = new Model;
-  const internal::ConverterMapType& converter_map =
-      internal::GetTensorFlowNodeConverterMap();
+  internal::ConverterMapType converter_map;
+
+  // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+  // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
+  // converted to TFLite Eager ops.
+  if (!tf_import_flags.import_all_ops_as_unsupported) {
+    converter_map = internal::GetTensorFlowNodeConverterMap();
+  }
 
   for (auto node : inlined_graph.node()) {
     StripZeroOutputIndexFromInputs(&node);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 2177872..7db23f2 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -27,6 +27,11 @@
   // If true, control dependencies will be dropped immediately
   // during the import of the TensorFlow GraphDef.
   bool drop_control_dependency = false;
+
+  // Do not recognize any op and import all ops as
+  // `TensorFlowUnsupportedOperator`. This is used together with the
+  // `force_eager_ops` flag.
+  bool import_all_ops_as_unsupported = false;
 };
 
 std::unique_ptr<Model> ImportTensorFlowGraphDef(
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 90e6f69..a00e136 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e3..164b70f 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -477,6 +477,11 @@
   int stride_height = 0;
   int stride_width = 0;
   int depth_multiplier = 0;
+  // A dilation_rate of 0 is invalid and this field is an optional attribute.
+  // Thus initializing it to 1 to allow default conv behavior when the
+  // attribute is not present.
+  int dilation_width_factor = 1;
+  int dilation_height_factor = 1;
 };
 
 // Depth-to-space transform operator.
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
index 3761e00..75c1c89 100644
--- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -50,7 +50,7 @@
     toco_flags.output_format = toco_flags_pb2.TFLITE
     toco_flags.inference_input_type = types_pb2.FLOAT
     toco_flags.inference_type = types_pb2.FLOAT
-    toco_flags.allow_custom_ops = True;
+    toco_flags.allow_custom_ops = True
     model_flags = model_flags_pb2.ModelFlags()
     input_array = model_flags.input_arrays.add()
     input_array.name = TensorName(in_tensor)
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index c79469f..fee10b1 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -49,12 +49,21 @@
 
 details::OperatorKey GetOperatorKey(
     const ::toco::Operator& op,
-    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+    bool allow_eager_ops) {
   string custom_code;
   if (op.type == OperatorType::kUnsupported) {
     const TensorFlowUnsupportedOperator& unsupported_op =
         static_cast<const TensorFlowUnsupportedOperator&>(op);
-    custom_code = unsupported_op.tensorflow_op;
+
+    // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way
+    // to populate a regular custom op. We need to find a way to fix this.
+    if (allow_eager_ops) {
+      custom_code = string(::tflite::kEagerCustomCodePrefix) +
+                    unsupported_op.tensorflow_op;
+    } else {
+      custom_code = unsupported_op.tensorflow_op;
+    }
   }
   int version = 1;
   if (ops_by_type.count(op.type) != 0) {
@@ -91,11 +100,12 @@
 
 void LoadOperatorsMap(
     const Model& model, OperatorsMap* operators_map,
-    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+    bool allow_eager_ops) {
   // First find a list of unique operator types.
   std::set<OperatorKey> keys;
   for (const auto& op : model.operators) {
-    keys.insert(GetOperatorKey(*op, ops_by_type));
+    keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
   }
   // Now assign indices to them and fill in the map.
   int index = 0;
@@ -189,7 +199,7 @@
     const Model& model,
     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
     const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
-    std::set<string>* error_summary) {
+    std::set<string>* error_summary, const ExportParams& params) {
   // Map from operator name to TF Lite enum value, for all builtins.
   std::map<string, BuiltinOperator> builtin_ops;
   for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -205,7 +215,8 @@
   std::map<int, Offset<OperatorCode>> ordered_opcodes;
 
   for (const auto& op : model.operators) {
-    const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
+    const details::OperatorKey operator_key =
+        GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
     int op_index = operators_map.at(operator_key);
     int op_version = operator_key.version;
 
@@ -252,7 +263,7 @@
     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
     const details::OperatorsMap& operators_map,
     const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
-    std::set<int32_t>* variable_tensor_indices) {
+    std::set<int32_t>* variable_tensor_indices, const ExportParams& params) {
   variable_tensor_indices->clear();
 
   // The operators are in execution order, so we just follow tf.mini order.
@@ -269,7 +280,8 @@
       outputs.push_back(tensors_map.at(output));
     }
 
-    int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+    int op_index = operators_map.at(
+        GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
 
     auto tflite_op_it = ops_by_type.find(op->type);
     BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -320,16 +332,15 @@
   return builder->CreateVector(buffer_vector);
 }
 
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
-            string* output_file_contents) {
-  const auto ops_by_type = BuildOperatorByTypeMap();
-  Export(model, allow_custom_ops, quantize_weights, output_file_contents,
-         ops_by_type);
+void Export(const Model& model, string* output_file_contents,
+            const ExportParams& params) {
+  const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+  Export(model, output_file_contents, params, ops_by_type);
 }
 
 void Export(
-    const Model& model, bool allow_custom_ops, bool quantize_weights,
-    string* output_file_contents,
+    const Model& model, string* output_file_contents,
+    const ExportParams& params,
     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
   flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
 
@@ -337,7 +348,8 @@
   details::LoadTensorsMap(model, &tensors_map);
 
   details::OperatorsMap operators_map;
-  details::LoadOperatorsMap(model, &operators_map, ops_by_type);
+  details::LoadOperatorsMap(model, &operators_map, ops_by_type,
+                            params.allow_eager_ops);
 
   std::vector<const Array*> buffers_to_write;
   Array empty_array;
@@ -345,7 +357,7 @@
 
   std::set<string> error_summary;
   auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
-                                      &builder, &error_summary);
+                                      &builder, &error_summary, params);
 
   for (const auto& op : model.operators) {
     if (op->type == OperatorType::kFakeQuant) {
@@ -355,7 +367,7 @@
                       "for --std_values and --mean_values.";
     }
   }
-  if (!allow_custom_ops && !error_summary.empty()) {
+  if (!params.allow_custom_ops && !error_summary.empty()) {
     // Remove ExpandDims and ReorderAxes from unimplemented list unless they
     // compose the list. Both ops are removed during graph transformations.
     // However, if an op is unimplemented earlier in the model, the graph
@@ -383,7 +395,7 @@
 
   std::set<int32_t> variable_tensor_indices;
   auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
-                             &builder, &variable_tensor_indices);
+                             &builder, &variable_tensor_indices, params);
 
   auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
                                variable_tensor_indices);
@@ -402,7 +414,7 @@
                   builder.CreateVector(subgraphs), description, buffers);
   ::tflite::FinishModelBuffer(builder, new_model_location);
 
-  if (quantize_weights) {
+  if (params.quantize_weights) {
     // Call the quantize_weights tool.
     LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
                  "dump_graphviz will only output the model before this "
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 915d5dd..b070a38 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -23,22 +23,54 @@
 
 namespace tflite {
 
+// The parameters for exporting a TFLite model.
+struct ExportParams {
+  bool allow_custom_ops = false;
+  bool allow_eager_ops = false;
+  bool quantize_weights = false;
+};
+
 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
 // result in the given string.
-void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
-            string* output_file_contents);
-
-// This if backward-compatibility.
-// TODO(ycling): Remove the deprecated entry functions.
-inline void Export(const Model& model, string* output_file_contents) {
-  Export(model, true, false, output_file_contents);
-}
+void Export(const Model& model, string* output_file_contents,
+            const ExportParams& params);
 
 // Export API with custom TFLite operator mapping.
 void Export(
+    const Model& model, string* output_file_contents,
+    const ExportParams& params,
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, bool allow_custom_ops,
+                   bool quantize_weights, string* output_file_contents) {
+  ExportParams params;
+  params.allow_custom_ops = allow_custom_ops;
+  params.quantize_weights = quantize_weights;
+  Export(model, output_file_contents, params);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(
     const Model& model, bool allow_custom_ops, bool quantize_weights,
     string* output_file_contents,
-    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+  ExportParams params;
+  params.allow_custom_ops = allow_custom_ops;
+  params.quantize_weights = quantize_weights;
+  Export(model, output_file_contents, params, ops_by_type);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, string* output_file_contents) {
+  // Legacy default behavior: allow custom ops, no weight quantization.
+  ExportParams params;
+  params.allow_custom_ops = true;
+  Export(model, output_file_contents, params);
+}
 
 namespace details {
 
@@ -88,7 +120,8 @@
 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
 void LoadOperatorsMap(
     const Model& model, OperatorsMap* operators_map,
-    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+    const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+    bool allow_eager_ops);
 
 }  // namespace details
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 4994ea3..8d4d197 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,7 +105,8 @@
 
   details::OperatorsMap operators;
   const auto ops_by_type = BuildOperatorByTypeMap();
-  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+  // TODO(ycling): Add a test for allow_eager_ops.
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
   EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
   EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
   EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
@@ -253,7 +254,7 @@
 
   details::OperatorsMap operators;
   const auto ops_by_type = BuildFakeOperatorByTypeMap();
-  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
 
   EXPECT_EQ(1, operators.size());
   EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -264,7 +265,7 @@
 
   details::OperatorsMap operators;
   const auto ops_by_type = BuildFakeOperatorByTypeMap();
-  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
 
   EXPECT_EQ(1, operators.size());
   EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
@@ -276,7 +277,7 @@
 
   details::OperatorsMap operators;
   const auto ops_by_type = BuildFakeOperatorByTypeMap();
-  details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+  details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
 
   EXPECT_EQ(2, operators.size());
   EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index a314c8d..1061e7c 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@
         ActivationFunction::Serialize(op.fused_activation_function);
     return ::tflite::CreateDepthwiseConv2DOptions(
         *builder, padding, op.stride_width, op.stride_height,
-        op.depth_multiplier, activation_function);
+        op.depth_multiplier, activation_function, op.dilation_width_factor,
+        op.dilation_height_factor);
   }
 
   void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@
     op->depth_multiplier = options.depth_multiplier();
     op->fused_activation_function =
         ActivationFunction::Deserialize(options.fused_activation_function());
+    op->dilation_width_factor = options.dilation_w_factor();
+    op->dilation_height_factor = options.dilation_h_factor();
   }
 
-  int GetVersion(const Operator& op) const override { return 1; }
+  int GetVersion(const Operator& op) const override {
+    const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+    if (conv_op.dilation_width_factor != 1 ||
+        conv_op.dilation_height_factor != 1) {
+      return 2;
+    }
+    return 1;
+  }
 };
 
 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
@@ -1149,7 +1159,9 @@
 
 class TensorFlowUnsupported : public BaseOperator {
  public:
-  using BaseOperator::BaseOperator;
+  TensorFlowUnsupported(const string& name, OperatorType type,
+                        bool allow_eager_ops)
+      : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
 
   Options Serialize(const Operator& op,
                     flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1165,6 +1177,9 @@
   std::unique_ptr<Operator> Deserialize(
       const BuiltinOptions* builtin_options,
       const CustomOptions* custom_options) const override {
+    // Deserializing Eager ops doesn't work now.
+    // TODO(ycling): Revisit and decide if we should fix the flow for importing
+    // TFLite models with Eager ops.
     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
     if (custom_options) {
       auto flexbuffer_map =
@@ -1185,6 +1200,16 @@
       return std::unique_ptr<flexbuffers::Builder>();
     }
 
+    if (allow_eager_ops_) {
+      fbb->Vector([&]() {
+        fbb->String(node_def.op());
+        fbb->String(op.tensorflow_node_def);
+      });
+      fbb->Finish();
+      LOG(INFO) << "Writing eager op: " << node_def.op();
+      return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+    }
+
     bool has_valid_attr = false;
     size_t map_start = fbb->StartMap();
     for (const auto& pair : node_def.attr()) {
@@ -1285,11 +1310,15 @@
     // custom ops.
     return 1;
   }
+
+ private:
+  const bool allow_eager_ops_;
 };
 
 namespace {
 // Build a vector containing all the known operators.
-std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
+    bool allow_eager_ops = false) {
   std::vector<std::unique_ptr<BaseOperator>> ops;
   using tensorflow::MakeUnique;
   // Builtin Operators.
@@ -1400,8 +1429,8 @@
       MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
-  ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
-                                                  OperatorType::kUnsupported));
+  ops.push_back(MakeUnique<TensorFlowUnsupported>(
+      "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
 
   // There operators are supported by Toco, but not by TF Lite, and has no
   // attributes.
@@ -1469,15 +1498,19 @@
       "SQRT", OperatorType::kSqrt));
   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
       "RSQRT", OperatorType::kRsqrt));
+  ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
+      "SQUARE", OperatorType::kSquare));
 
   return ops;
 }
 }  // namespace
 
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+    bool allow_eager_ops) {
   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
 
-  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+  std::vector<std::unique_ptr<BaseOperator>> ops =
+      BuildOperatorList(allow_eager_ops);
   for (auto& op : ops) {
     result[op->type()] = std::move(op);
   }
@@ -1485,10 +1518,12 @@
   return result;
 }
 
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+    bool allow_eager_ops) {
   std::map<string, std::unique_ptr<BaseOperator>> result;
 
-  std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+  std::vector<std::unique_ptr<BaseOperator>> ops =
+      BuildOperatorList(allow_eager_ops);
   for (auto& op : ops) {
     result[op->name()] = std::move(op);
   }
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index d9ea23e..702fb28 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,11 +26,15 @@
 class BaseOperator;
 
 // Return a map contained all know TF Lite Operators, keyed by their names.
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// is ugly here. Consider refactoring.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+    bool allow_eager_ops = false);
 
 // Return a map contained all know TF Lite Operators, keyed by the type of
 // their tf.mini counterparts.
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+    bool allow_eager_ops = false);
 
 // These are the flatbuffer types for custom and builtin options.
 using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 519a3a4..72e50a9 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -144,6 +144,8 @@
   CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
                                           OperatorType::kLogicalNot);
   CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
+  CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
+                                                OperatorType::kSquare);
 }
 
 TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index f83a290..b6aebc0 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -165,7 +165,13 @@
            parsed_flags.post_training_quantize.default_value(),
            "Boolean indicating whether to quantize the weights of the "
            "converted float model. Model size will be reduced and there will "
-           "be latency improvements (at the cost of accuracy).")};
+           "be latency improvements (at the cost of accuracy)."),
+      // WARNING: Experimental interface, subject to change
+      Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
+           parsed_flags.allow_eager_ops.default_value(), ""),
+      // WARNING: Experimental interface, subject to change
+      Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
+           parsed_flags.force_eager_ops.default_value(), "")};
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
   if (asked_for_help) {
@@ -260,6 +266,16 @@
   READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
   READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
   READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
+  READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
+  READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+
+  if (parsed_toco_flags.force_eager_ops.value() &&
+      !parsed_toco_flags.allow_eager_ops.value()) {
+    // TODO(ycling): Consider to enforce `allow_eager_ops` when
+    // `force_eager_ops` is true.
+    LOG(WARNING) << "--force_eager_ops should always be used with "
+                    "--allow_eager_ops.";
+  }
 
   // Deprecated flag handling.
   if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index c1dd621..53d60fe 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 27.
+// Next ID to use: 29.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -189,4 +189,17 @@
   // model. Model size will be reduced and there will be latency improvements
   // (at the cost of accuracy).
   optional bool post_training_quantize = 26 [default = false];
+
+  // When enabled, unsupported ops will be converted to TFLite Eager ops.
+  // TODO(ycling): Consider to rename the following 2 flags and don't call it
+  // "Eager".
+  // `allow_eager_ops` should always be used with `allow_custom_ops`.
+  // WARNING: Experimental interface, subject to change
+  optional bool allow_eager_ops = 27 [default = false];
+
+  // When enabled, all TensorFlow ops will be converted to TFLite Eager
+  // ops directly. This will force `allow_eager_ops` to true.
+  // `force_eager_ops` should always be used with `allow_eager_ops`.
+  // WARNING: Experimental interface, subject to change
+  optional bool force_eager_ops = 28 [default = false];
 }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 7db7acb..a7c1715 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -197,6 +197,10 @@
           toco_flags.has_drop_control_dependency()
               ? toco_flags.drop_control_dependency()
               : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
+
+      tf_import_flags.import_all_ops_as_unsupported =
+          toco_flags.force_eager_ops();
+
       model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
                                        input_file_contents);
       break;
@@ -397,11 +401,21 @@
     case TENSORFLOW_GRAPHDEF:
       ExportTensorFlowGraphDef(model, output_file_contents);
       break;
-    case TFLITE:
-      toco::tflite::Export(model, allow_custom_ops,
-                           toco_flags.post_training_quantize(),
-                           output_file_contents);
-      break;
+    case TFLITE: {
+      toco::tflite::ExportParams params;
+
+      // Always allow custom ops when eager ops are allowed.
+      if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
+        params.allow_eager_ops = true;
+        params.allow_custom_ops = true;
+      } else if (allow_custom_ops) {
+        params.allow_custom_ops = true;
+      }
+
+      params.quantize_weights = toco_flags.post_training_quantize();
+
+      toco::tflite::Export(model, output_file_contents, params);
+    } break;
     case GRAPHVIZ_DOT:
       DumpGraphviz(model, output_file_contents);
       break;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index bdeb203..5f4b8cb 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@
 #if TOCO_SUPPORT_PORTABLE_PROTOS
 #include "third_party/protobuf/include/google/protobuf/text_format.h"
 #endif  // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
 #include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@
 // - For the remaining indices [0..i0), d0[i0] == 1.
 bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
 
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+  return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
 bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
 
 // If there is a wildcard dimension (-1), this may return a negative value.
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
index a66812f..98e2835 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -54,6 +54,7 @@
     linkopts = common_linkopts,
     linkstatic = 1,
     tags = [
+        "no_oss",  # b/114307765
         "tflite_not_portable_android",
         "tflite_not_portable_ios",
     ],
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 0203992..ef4f0fa 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -232,6 +232,46 @@
   return total_input_bytes;
 }
 
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+  auto interpreter_inputs = interpreter->inputs();
+  // Set the values of the input tensors.
+  for (int j = 0; j < inputs.size(); ++j) {
+    const InputLayerInfo& input = inputs[j];
+    int i = interpreter_inputs[j];
+    TfLiteTensor* t = interpreter->tensor(i);
+    std::vector<int> sizes = input.shape;
+
+    // TODO(ahentz): below we ignore the O-th dimension (number of batches).
+    if (t->type == kTfLiteFloat32) {
+      FillRandomValue<float>(
+          interpreter->typed_tensor<float>(i),
+          std::vector<int>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+    } else if (t->type == kTfLiteInt32) {
+      // TODO(yunluli): This is currently only used for handling embedding input
+      // for speech models. Generalize if necessary.
+      FillRandomValue<int32_t>(
+          interpreter->typed_tensor<int32_t>(i),
+          std::vector<int32_t>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<int32_t>(rand()) % 100; });
+    } else if (t->type == kTfLiteUInt8) {
+      FillRandomValue<uint8_t>(
+          interpreter->typed_tensor<uint8_t>(i),
+          std::vector<int>(sizes.begin() + 1, sizes.end()),
+          []() { return static_cast<uint8_t>(rand()) % 255; });
+    } else if (t->type == kTfLiteString) {
+      tflite::DynamicBuffer buffer;
+      FillRandomString(&buffer, sizes, []() {
+        return "we're have some friends over saturday to hang out in the yard";
+      });
+      buffer.WriteToTensor(interpreter->tensor(i));
+    } else {
+      TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+                        << " of type " << t->type;
+    }
+  }
+}
+
 void BenchmarkTfLiteModel::Init() {
   std::string graph = params_.Get<std::string>("graph");
   model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -305,36 +345,6 @@
   if (interpreter->AllocateTensors() != kTfLiteOk) {
     TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
   }
-
-  // Set the values of the input tensors.
-  for (int j = 0; j < inputs.size(); ++j) {
-    const InputLayerInfo& input = inputs[j];
-    int i = interpreter_inputs[j];
-    TfLiteTensor* t = interpreter->tensor(i);
-    std::vector<int> sizes = input.shape;
-
-    // TODO(ahentz): below we ignore the O-th dimension (number of batches).
-    if (t->type == kTfLiteFloat32) {
-      FillRandomValue<float>(
-          interpreter->typed_tensor<float>(i),
-          std::vector<int>(sizes.begin() + 1, sizes.end()),
-          []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
-    } else if (t->type == kTfLiteUInt8) {
-      FillRandomValue<uint8_t>(
-          interpreter->typed_tensor<uint8_t>(i),
-          std::vector<int>(sizes.begin() + 1, sizes.end()),
-          []() { return static_cast<uint8_t>(rand()) % 255; });
-    } else if (t->type == kTfLiteString) {
-      tflite::DynamicBuffer buffer;
-      FillRandomString(&buffer, sizes, []() {
-        return "we're have some friends over saturday to hang out in the yard";
-      });
-      buffer.WriteToTensor(interpreter->tensor(i));
-    } else {
-      TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
-                        << " of type " << t->type;
-    }
-  }
 }
 
 void BenchmarkTfLiteModel::RunImpl() {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 4c4320a..8541512 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -69,6 +69,9 @@
     std::vector<int> shape;
   };
 
+ protected:
+  void PrepareInputsAndOutputs() override;
+
  private:
 #ifdef TFLITE_EXTENDED
   std::unique_ptr<EagerDelegate> delegate_;
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index e30cc1d..59bdb10 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -24,6 +24,21 @@
 TARGET := $(HOST_OS)
 TARGET_ARCH := $(HOST_ARCH)
 
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
+-I$(MAKEFILE_DIR)/downloads/farmhash/src \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
 # These are the default libraries needed, but they can be added to or
 # overridden by the platform-specific settings in target makefiles.
 LIBS := \
@@ -44,55 +59,17 @@
 TARGET_TOOLCHAIN_PREFIX :=
 CC_PREFIX :=
 
-# These target-specific makefiles should modify or replace options like
-# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
-# based on platforms or architectures should happen within these files, to
-# keep this main makefile focused on the sources and dependencies.
-include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
-
-# Where compiled objects are stored.
-GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
-OBJDIR := $(GENDIR)obj/
-BINDIR := $(GENDIR)bin/
-LIBDIR := $(GENDIR)lib/
-
-INCLUDES := \
--I. \
--I$(MAKEFILE_DIR)/../../../../../ \
--I$(MAKEFILE_DIR)/../../../../../../ \
--I$(MAKEFILE_DIR)/downloads/ \
--I$(MAKEFILE_DIR)/downloads/eigen \
--I$(MAKEFILE_DIR)/downloads/gemmlowp \
--I$(MAKEFILE_DIR)/downloads/neon_2_sse \
--I$(MAKEFILE_DIR)/downloads/farmhash/src \
--I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(OBJDIR)
-# This is at the end so any globally-installed frameworks like protobuf don't
-# override local versions in the source tree.
-INCLUDES += -I/usr/local/include
-
-CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
-CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
-
 # This library is the main target for this makefile. It will contain a minimal
 # runtime that can be linked in to other programs.
 LIB_NAME := libtensorflow-lite.a
-LIB_PATH := $(LIBDIR)$(LIB_NAME)
-
-# A small example program that shows how to link against the library.
-MINIMAL_PATH := $(BINDIR)minimal
 
 # Benchmark static library and binary
 BENCHMARK_LIB_NAME := benchmark-lib.a
 BENCHMARK_BINARY_NAME := benchmark_model
-BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
-BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
 
+# A small example program that shows how to link against the library.
 MINIMAL_SRCS := \
 tensorflow/contrib/lite/examples/minimal/minimal.cc
-MINIMAL_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
 
 # What sources we want to compile, must be kept in sync with the main Bazel
 # build files.
@@ -105,7 +82,9 @@
 
 CORE_CC_ALL_SRCS := \
 $(wildcard tensorflow/contrib/lite/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c)
+$(wildcard tensorflow/contrib/lite/*.c) \
+$(wildcard tensorflow/contrib/lite/c/*.c) \
+$(wildcard tensorflow/contrib/lite/core/api/*.cc)
 ifneq ($(BUILD_TYPE),micro)
 CORE_CC_ALL_SRCS += \
 $(wildcard tensorflow/contrib/lite/kernels/*.cc) \
@@ -136,10 +115,6 @@
 endif
 # Filter out all the excluded files.
 TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
-# File names of the intermediate files target compilation generates.
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
-LIB_OBJS := $(TF_LITE_CC_OBJS)
 
 # Benchmark sources
 BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
@@ -151,6 +126,40 @@
 	$(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
     $(BENCHMARK_ALL_SRCS))
 
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+	$(MINIMAL_SRCS) \
+	$(PROFILER_SRCS) \
+	$(PROFILER_SUMMARY_SRCS) \
+	$(TF_LITE_CC_SRCS) \
+	$(BENCHMARK_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+LIB_PATH := $(LIBDIR)$(LIB_NAME)
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+MINIMAL_BINARY := $(BINDIR)minimal
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
+
+LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
+
 BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
 $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
 
@@ -164,7 +173,7 @@
 	$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
 
 # The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH)  $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+all: $(LIB_PATH)  $(MINIMAL_BINARY) $(BENCHMARK_BINARY)
 
 # The target that's compiled for micro-controllers
 micro: $(LIB_PATH)
@@ -178,19 +187,18 @@
 	@mkdir -p $(dir $@)
 	$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
 
-$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
+$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH)
 	@mkdir -p $(dir $@)
 	$(CXX) $(CXXFLAGS) $(INCLUDES) \
-	-o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
+	-o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \
 	$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
 
-
 $(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
 	@mkdir -p $(dir $@)
 	$(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
 
 benchmark_lib: $(BENCHMARK_LIB)
-$(info $(BENCHMARK_BINARY))
+
 $(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
 	@mkdir -p $(dir $@)
 	$(CXX) $(CXXFLAGS) $(INCLUDES) \
@@ -213,4 +221,4 @@
 $(DEPDIR)/%.d: ;
 .PRECIOUS: $(DEPDIR)/%.d
 
--include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS)))
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index e0ed7c7..d02d78b 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -42,10 +42,9 @@
   bool eval_hybrid;
 } TensorInfo;
 
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this configurable.
-const int kWeightsMinSize = 1024;
+// The default minimum number of elements a weights array must have to be
+// quantized by this transformation.
+const int kWeightsMinNumElementsDefault = 1024;
 
 // Nudge min and max so that floating point 0 falls exactly on a quantized
 // value, returning the nudges scale and zero_point.
@@ -142,6 +141,7 @@
       op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
       op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
       op_code == BuiltinOperator_RNN ||
+      op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
       op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
       op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
       op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
@@ -158,43 +158,59 @@
 
 // Returns a vector of TensorInfos for each input tensor of op that should be
 // quantized.
-std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model,
-                                                          const OperatorT* op) {
+std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
+    const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
+    bool use_hybrid_evaluation) {
   SubGraphT* subgraph = model->subgraphs.at(0).get();
   const BuiltinOperator op_code =
       model->operator_codes[op->opcode_index]->builtin_code;
 
   std::vector<TensorInfo> tensor_infos;
 
-  bool eval_hybrid = IsHybridEvaluationOp(op, op_code);
+  bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code);
 
-  bool skipped_tensor = false;
   std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
   for (const int32_t op_input_idx : op_input_indices) {
     int32_t tensor_idx = op->inputs[op_input_idx];
 
-    // TODO(suharshs): Support shared weights, i.e. If two tensors share the
-    // same weight array, things may break. (i.e. SSD object detection)
-    if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
-      LOG(INFO) << "Skipping quantization of tensor that is shared between "
-                   "multiple multiple operations.";
-      skipped_tensor = true;
+    if (tensor_idx == -1) {
+      LOG(INFO) << "Skipping optional tensor input " << op_input_idx
+                << " of operation " << EnumNameBuiltinOperator(op_code);
       continue;
     }
 
     TensorT* tensor = subgraph->tensors[tensor_idx].get();
+    // TODO(suharshs): Support shared weights, i.e. If two tensors share the
+    // same weight array, things may break. (i.e. SSD object detection)
+    if (!eval_hybrid &&
+        CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
+      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+                << " that is shared between multiple multiple operations.";
+      continue;
+    }
 
     if (tensor->type != TensorType_FLOAT32) {
-      LOG(INFO) << "Skipping quantization of tensor that is not type float.";
-      skipped_tensor = true;
+      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+                << " that is not type float.";
       continue;
     }
 
     const uint64_t num_elements = NumElements(tensor);
-    if (num_elements < kWeightsMinSize) {
-      LOG(INFO) << "Skipping quantization of tensor because it has fewer than "
-                << kWeightsMinSize << " elements (" << num_elements << ").";
-      skipped_tensor = true;
+    if (num_elements < weights_min_num_elements) {
+      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+                << " because it has fewer than " << weights_min_num_elements
+                << " elements (" << num_elements << ").";
+      // If one of the weights isn't quantized, then we cannot use the hybrid
+      // kernel for this operation, since it expects everything to be quantized.
+      eval_hybrid = false;
+      continue;
+    }
+
+    // Some tensors may have a null buffer vector, indicating an intermediate
+    // array.
+    if (model->buffers[tensor->buffer]->data.data() == nullptr) {
+      LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+                << " because it has no allocated buffer.";
       continue;
     }
 
@@ -207,12 +223,6 @@
     tensor_infos.push_back(tensor_info);
   }
 
-  // For hybrid operations we either need to quantize all tensors or none. So
-  // if we skipped any tensors we need to return no quantized tensors.
-  if (eval_hybrid && skipped_tensor) {
-    return {};
-  }
-
   return tensor_infos;
 }
 
@@ -331,11 +341,10 @@
   tensor->reset(tensor_raw);
 }
 
-}  // namespace
-
-TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
-                             const Model* input_model,
-                             bool use_hybrid_evaluation) {
+TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
+                                     const Model* input_model,
+                                     bool use_hybrid_evaluation,
+                                     uint64_t weights_min_num_elements) {
   std::unique_ptr<ModelT> model;
   model.reset(input_model->UnPack());
 
@@ -352,11 +361,11 @@
   for (int i = 0; i < subgraph->operators.size(); ++i) {
     OperatorT* op = subgraph->operators[i].get();
 
-    std::vector<TensorInfo> tensor_infos =
-        GetQuantizableTensorsFromOperator(model.get(), op);
+    std::vector<TensorInfo> tensor_infos = GetQuantizableTensorsFromOperator(
+        model.get(), op, weights_min_num_elements, use_hybrid_evaluation);
 
     for (const TensorInfo& tensor_info : tensor_infos) {
-      if (use_hybrid_evaluation && tensor_info.eval_hybrid) {
+      if (tensor_info.eval_hybrid) {
         // Quantize the tensor.
         TF_LITE_ENSURE_STATUS(
             SymmetricQuantizeTensor(model.get(), tensor_info.tensor));
@@ -399,9 +408,32 @@
   return kTfLiteOk;
 }
 
+}  // namespace
+
+namespace internal {
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model,
+                             bool use_hybrid_evaluation) {
+  // By default we require that only weights with more than
+  // kWeightsMinSizeDefault elements are quantized.
+  return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
+                                 kWeightsMinNumElementsDefault);
+}
+}  // namespace internal
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model,
+                             uint64_t weights_min_num_elements) {
+  return QuantizeWeightsInternal(builder, input_model, true,
+                                 weights_min_num_elements);
+}
+
 TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                              const Model* input_model) {
-  return QuantizeWeights(builder, input_model, true);
+  // By default we require that only weights with more than
+  // kWeightsMinSizeDefault elements are quantized.
+  return QuantizeWeightsInternal(builder, input_model, true,
+                                 kWeightsMinNumElementsDefault);
 }
 
 }  // namespace optimize
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
index 3743c0c..706f10b 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -25,6 +25,8 @@
 namespace optimize {
 
 // Quantizes input_model and populates the provided builder with the new model.
+// By default only weights tensors weight more than 1024 elements will be
+// quantized.
 //
 // A tflite::Model can be obtained from the builder with:
 //   const uint8_t* buffer = builder->GetBufferPointer();
@@ -32,11 +34,22 @@
 TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                              const Model* input_model);
 
-// Same as above, but if use_hybrid_evaluation is false, will disable using
-// hybrid eval for operations that support it.
+// Same as above, but only weight tensors with at least
+// weights_min_num_elements elements will be quantized.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+                             const Model* input_model,
+                             uint64_t weights_min_num_elements);
+
+namespace internal {
+// If use_hybrid_evaluation is false, will disable using hybrid eval for
+// operations that support it.
+//
+// We use this internal QuantizeWeights call to test models with hybrid
+// evaluation disabled.
 TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                              const Model* input_model,
                              bool use_hybrid_evaluation);
+}  // namespace internal
 
 }  // namespace optimize
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
index efaf992..387b347 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -76,7 +76,8 @@
 
   void CheckWeights(const Model* input_model_packed,
                     const Model* output_model_packed,
-                    bool use_hybrid_evaluation) {
+                    bool use_hybrid_evaluation,
+                    uint64_t weights_min_num_elements = 1024) {
     std::unique_ptr<ModelT> input_model;
     input_model.reset(input_model_packed->UnPack());
 
@@ -113,8 +114,9 @@
       int tensor_size = GetElementsNum(tensor);
       // If the tensor_size is less than 1024 we expect the tensor to remain
       // unquantized.
-      if (tensor_size < 1024) {
-        ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+      if (tensor_size < weights_min_num_elements) {
+        ASSERT_TRUE(tensor->type == TensorType_FLOAT32)
+            << tensor->name << " of type " << tensor->type;
         const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
         // The weight tensor should not come from a dequantize op.
         ASSERT_TRUE(preceding_op == nullptr);
@@ -183,7 +185,7 @@
 
   flatbuffers::FlatBufferBuilder builder;
   // Disable hybrid evaluation.
-  EXPECT_EQ(QuantizeWeights(&builder, input_model, false), kTfLiteOk);
+  EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk);
 
   const uint8_t* buffer = builder.GetBufferPointer();
   const Model* output_model = GetModel(buffer);
@@ -191,6 +193,26 @@
   CheckWeights(input_model, output_model, false);
 }
 
+TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) {
+  string model_path =
+      "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+      "mobilenet_v1_0.25_128.tflite";
+  std::unique_ptr<FlatBufferModel> input_fb =
+      FlatBufferModel::BuildFromFile(model_path.data());
+  const Model* input_model = input_fb->GetModel();
+
+  flatbuffers::FlatBufferBuilder builder;
+  // Make weights_min_num_elements sufficiently large such that no quantization
+  // should happen, i.e. the resulting model is the same as the original.
+  const uint64_t kWeightsMinNumElements = 1000000;
+  EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements),
+            kTfLiteOk);
+
+  const uint8_t* buffer = builder.GetBufferPointer();
+  const Model* output_model = GetModel(buffer);
+  CheckWeights(input_model, output_model, true, kWeightsMinNumElements);
+}
+
 // TODO(suharshs): Add tests that run the resulting model.
 
 }  // namespace
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index 597dede..d7eea79 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -202,7 +202,7 @@
       html += str(i) + " "
       html += tensor["name"] + " "
       html += str(tensor["type"]) + " "
-      html += repr(tensor["shape"]) + "<br>"
+      html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + "<br>"
     html += "</span>"
     html += repr(x)
     html += "</span>"
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000..67ff1ea
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+    name = "mnist_tflite",
+    srcs = [
+        "dataset.py",
+        "mnist_tflite.py",
+    ],
+    deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000..ba49dfc
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+#  Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+ This is cloned from
+ https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+  """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+  dt = np.dtype(np.uint32).newbyteorder('>')
+  return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+  """Validate that filename corresponds to images for the MNIST dataset."""
+  with tf.gfile.Open(filename, 'rb') as f:
+    magic = read32(f)
+    read32(f)  # num_images, unused
+    rows = read32(f)
+    cols = read32(f)
+    if magic != 2051:
+      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+                                                                     f.name))
+    if rows != 28 or cols != 28:
+      raise ValueError(
+          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+          (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+  """Validate that filename corresponds to labels for the MNIST dataset."""
+  with tf.gfile.Open(filename, 'rb') as f:
+    magic = read32(f)
+    read32(f)  # num_items, unused
+    if magic != 2049:
+      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+                                                                     f.name))
+
+
+def download(directory, filename):
+  """Download (and unzip) a file from the MNIST dataset if not already done."""
+  filepath = os.path.join(directory, filename)
+  if tf.gfile.Exists(filepath):
+    return filepath
+  if not tf.gfile.Exists(directory):
+    tf.gfile.MakeDirs(directory)
+  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+  _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+  print('Downloading %s to %s' % (url, zipped_filepath))
+  urllib.request.urlretrieve(url, zipped_filepath)
+  with gzip.open(zipped_filepath, 'rb') as f_in, \
+      tf.gfile.Open(filepath, 'wb') as f_out:
+    shutil.copyfileobj(f_in, f_out)
+  os.remove(zipped_filepath)
+  return filepath
+
+
+def dataset(directory, images_file, labels_file):
+  """Download and parse MNIST dataset."""
+
+  images_file = download(directory, images_file)
+  labels_file = download(directory, labels_file)
+
+  check_image_file_header(images_file)
+  check_labels_file_header(labels_file)
+
+  def decode_image(image):
+    # Normalize from [0, 255] to [0.0, 1.0]
+    image = tf.decode_raw(image, tf.uint8)
+    image = tf.cast(image, tf.float32)
+    image = tf.reshape(image, [784])
+    return image / 255.0
+
+  def decode_label(label):
+    label = tf.decode_raw(label, tf.uint8)  # tf.string -> [tf.uint8]
+    label = tf.reshape(label, [])  # label is a scalar
+    return tf.to_int32(label)
+
+  images = tf.data.FixedLengthRecordDataset(
+      images_file, 28 * 28, header_bytes=16).map(decode_image)
+  labels = tf.data.FixedLengthRecordDataset(
+      labels_file, 1, header_bytes=8).map(decode_label)
+  return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+  """tf.data.Dataset object for MNIST training data."""
+  return dataset(directory, 'train-images-idx3-ubyte',
+                 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+  """tf.data.Dataset object for MNIST test data."""
+  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000..7b8bf5b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+                    'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+                    'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+  # Generates an iterator over images
+  with tf.Session() as sess:
+    input_data = dataset.test(
+        flags.data_dir).make_one_shot_iterator().get_next()
+    try:
+      while True:
+        yield sess.run(input_data)
+    except tf.errors.OutOfRangeError:
+      pass
+
+
+def run_eval(interpreter, input_image):
+  """Performs evaluation for input image over specified model.
+
+  Args:
+      interpreter: TFLite interpreter initialized with model to execute.
+      input_image: Image input to the model.
+
+  Returns:
+      output: output tensor of model being executed.
+  """
+
+  # Get input and output tensors.
+  input_details = interpreter.get_input_details()
+  output_details = interpreter.get_output_details()
+
+  # Test model on the input images.
+  input_image = np.reshape(input_image, input_details[0]['shape'])
+  interpreter.set_tensor(input_details[0]['index'], input_image)
+
+  interpreter.invoke()
+  output_data = interpreter.get_tensor(output_details[0]['index'])
+  output = np.squeeze(output_data)
+  return output
+
+
+def main(_):
+  interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+  interpreter.allocate_tensors()
+  num_correct, total = 0, 0
+  for input_data in test_image_generator():
+    output = run_eval(interpreter, input_data[0])
+    total += 1
+    if output == input_data[1]:
+      num_correct += 1
+    if total % 500 == 0:
+      print('Accuracy after %i images: %f' %
+            (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.app.run(main)
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
new file mode 100644
index 0000000..a96e2c4
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -0,0 +1,702 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "6Y8E0lw5eYWm"
+      },
+      "source": [
+        "# Post Training Quantization"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "CIGrZZPTZVeO"
+      },
+      "source": [
+        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+        "  \u003ctd\u003e\n",
+        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+        "  \u003c/td\u003e\n",
+        "  \u003ctd\u003e\n",
+        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+        "  \u003c/td\u003e\n",
+        "\u003c/table\u003e"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "BTC1rDAuei_1"
+      },
+      "source": [
+        "## Overview\n",
+        "\n",
+        "[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) now supports\n",
+        "converting weights to 8 bit precision as part of model conversion from\n",
+        "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
+        "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
+        "fly quantization and dequantization of activations to allow for:\n",
+        "\n",
+        "1.  Using quantized kernels for faster implementation when available.\n",
+        "\n",
+        "2.  Mixing of floating-point kernels with quantized kernels for different parts\n",
+        "    of the graph.\n",
+        "\n",
+        "Note that the activations are always stored in floating point. For ops that\n",
+        "support quantized kernels, the activations are quantized to 8 bits of precision\n",
+        "dynamically prior to processing and are de-quantized to float precision after\n",
+        "processing. Depending on the model being converted, this can give a speedup over\n",
+        "pure floating point computation.\n",
+        "\n",
+        "In contrast to\n",
+        "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)\n",
+        ", the weights are quantized post training and the activations are quantized dynamically \n",
+        "at inference in this method.\n",
+        "Therefore, the model weights are not retrained to compensate for quantization\n",
+        "induced errors. It is important to check the accuracy of the quantized model to\n",
+        "ensure that the degradation is acceptable.\n",
+        "\n",
+        "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n",
+        "tensorflow and then convert the saved model into a Tensorflow Lite flatbuffer\n",
+        "with weight quantization. We finally check the\n",
+        "accuracy of the converted model and compare it to the original saved model. We\n",
+        "run the training script mnist.py from\n",
+        "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2XsEP17Zelz9"
+      },
+      "source": [
+        "## Building an MNIST model"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "dDqqUIZjZjac"
+      },
+      "source": [
+        "### Setup"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "gyqAw1M9lyab"
+      },
+      "outputs": [],
+      "source": [
+        "! pip uninstall -y tensorflow\n",
+        "! pip install -U tf-nightly"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "WsN6s5L1ieNl"
+      },
+      "outputs": [],
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "00U0taBoe-w7"
+      },
+      "outputs": [],
+      "source": [
+        "! git clone --depth 1 https://github.com/tensorflow/models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "4XZPtSh-fUOc"
+      },
+      "outputs": [],
+      "source": [
+        "import sys\n",
+        "import os\n",
+        "\n",
+        "if sys.version_info.major \u003e= 3:\n",
+        "    import pathlib\n",
+        "else:\n",
+        "    import pathlib2 as pathlib\n",
+        "\n",
+        "# Add `models` to the python path.\n",
+        "models_path = os.path.join(os.getcwd(), \"models\")\n",
+        "sys.path.append(models_path)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "eQ6Q0qqKZogR"
+      },
+      "source": [
+        "### Train and export the model"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "eMsw_6HujaqM"
+      },
+      "outputs": [],
+      "source": [
+        "saved_models_root = \"/tmp/mnist_saved_model\""
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "hWSAjQWagIHl"
+      },
+      "outputs": [],
+      "source": [
+        "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+        "# Note: channels_last is required here or the conversion may fail. \n",
+        "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "5NMaNZQCkW9X"
+      },
+      "source": [
+        "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n",
+        "\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "xl8_fzVAZwOh"
+      },
+      "source": [
+        "### Convert to a TFLite model\n",
+        "\n",
+        "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "Xp5oClaZkbtn"
+      },
+      "outputs": [],
+      "source": [
+        "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+        "saved_model_dir"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "AT8BgkKmljOy"
+      },
+      "source": [
+        "Using the python `TocoConverter`, the saved model can be converted into a TFLite model.\n",
+        "\n",
+        "First load the model using the `TocoConverter`:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "_i8B2nDZmAgQ"
+      },
+      "outputs": [],
+      "source": [
+        "import tensorflow as tf\n",
+        "tf.enable_eager_execution()\n",
+        "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n",
+        "tflite_model = converter.convert()"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "F2o2ZfF0aiCx"
+      },
+      "source": [
+        "Write it out to a tflite file:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "vptWZq2xnclo"
+      },
+      "outputs": [],
+      "source": [
+        "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+        "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "Ie9pQaQrn5ue"
+      },
+      "outputs": [],
+      "source": [
+        "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+        "tflite_model_file.write_bytes(tflite_model)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "7BONhYtYocQY"
+      },
+      "source": [
+        "To quantize the model on export, set the `post_training_quantize` flag:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "g8PUvLWDlmmz"
+      },
+      "outputs": [],
+      "source": [
+        "# Note: If you don't have a recent tf-nightly installed, the\n",
+        "# \"post_training_quantize\" line will have no effect.\n",
+        "tf.logging.set_verbosity(tf.logging.INFO)\n",
+        "converter.post_training_quantize = True\n",
+        "tflite_quant_model = converter.convert()\n",
+        "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+        "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "PhMmUTl4sbkz"
+      },
+      "source": [
+        "Note how the resulting file, with `post_training_quantize` set, is approximately `1/4` the size."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "JExfcfLDscu4"
+      },
+      "outputs": [],
+      "source": [
+        "!ls -lh {tflite_models_dir}"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L8lQHMp_asCq"
+      },
+      "source": [
+        "## Run the TFLite models"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-5l6-ciItvX6"
+      },
+      "source": [
+        "We can run the TensorFlow Lite model using the python TensorFlow Lite\n",
+        "Interpreter. \n",
+        "\n",
+        "### load the test data\n",
+        "\n",
+        "First let's load the mnist test data to feed to it:"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "eTIuU07NuKFL"
+      },
+      "outputs": [],
+      "source": [
+        "import numpy as np\n",
+        "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+        "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n",
+        "\n",
+        "# Note: If you change the batch size, then use \n",
+        "# `tf.contrib.lite.Interpreter.resize_tensor_input` to also change it for\n",
+        "# the interpreter.\n",
+        "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Ap_jE7QRvhPf"
+      },
+      "source": [
+        "### Load the model into an interpreter"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "Jn16Rc23zTss"
+      },
+      "outputs": [],
+      "source": [
+        "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n",
+        "interpreter.allocate_tensors()\n",
+        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
+        "output_index = interpreter.get_output_details()[0][\"index\"]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "J8Pztk1mvNVL"
+      },
+      "outputs": [],
+      "source": [
+        "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+        "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "Afl6yGvWyqAr"
+      },
+      "outputs": [],
+      "source": [
+        "interpreter_quant.allocate_tensors()\n",
+        "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n",
+        "output_index = interpreter_quant.get_output_details()[0][\"index\"]\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "2opUt_JTdyEu"
+      },
+      "source": [
+        "### Test the model on one image"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "AKslvo2kwWac"
+      },
+      "outputs": [],
+      "source": [
+        "for img, label in mnist_ds.take(1):\n",
+        "  break\n",
+        "\n",
+        "interpreter.set_tensor(input_index, img)\n",
+        "interpreter.invoke()\n",
+        "predictions = interpreter.get_tensor(output_index)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "XZClM2vo3_bm"
+      },
+      "outputs": [],
+      "source": [
+        "import matplotlib.pylab as plt\n",
+        "\n",
+        "plt.imshow(img[0])\n",
+        "template = \"True:{true}, predicted:{predict}\"\n",
+        "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+        "                              predict=str(predictions[0,0])))\n",
+        "plt.grid(False)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "LwN7uIdCd8Gw"
+      },
+      "source": [
+        "### Evaluate the models"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "05aeAuWjvjPx"
+      },
+      "outputs": [],
+      "source": [
+        "def eval_model(interpreter, mnist_ds):\n",
+        "  total_seen = 0\n",
+        "  num_correct = 0\n",
+        "\n",
+        "  for img, label in mnist_ds:\n",
+        "    total_seen += 1\n",
+        "    interpreter.set_tensor(input_index, img)\n",
+        "    interpreter.invoke()\n",
+        "    predictions = interpreter.get_tensor(output_index)\n",
+        "    if predictions == label.numpy():\n",
+        "      num_correct += 1\n",
+        "\n",
+        "    if total_seen % 500 == 0:\n",
+        "        print(\"Accuracy after %i images: %f\" %\n",
+        "              (total_seen, float(num_correct) / float(total_seen)))\n",
+        "\n",
+        "  return float(num_correct) / float(total_seen)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "DqXBnDfJ7qxL"
+      },
+      "outputs": [],
+      "source": [
+        "print(eval_model(interpreter_quant, mnist_ds))"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "Km3cY9ry8ZlG"
+      },
+      "source": [
+        "We can repeat the evaluation on the weight quantized model to obtain:\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "-9cnwiPp6EGm"
+      },
+      "outputs": [],
+      "source": [
+        "print(eval_model(interpreter_quant, mnist_ds))\n"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "L7lfxkor8pgv"
+      },
+      "source": [
+        "\n",
+        "In this example, we have compressed model with no difference in the accuracy."
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "M0o1FtmWeKZm"
+      },
+      "source": [
+        "\n",
+        "\n",
+        "## Optimizing an existing model\n",
+        "\n",
+        "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
+        "  Pre-trained frozen graph for resnet-v2-101 is available at the\n",
+        "  [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n",
+        "\n",
+        "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "v5p5VcNPjILQ"
+      },
+      "outputs": [],
+      "source": [
+        "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
+        "archive_path = pathlib.Path(archive_path)"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "-sxnXQuC4ThD"
+      },
+      "source": [
+        "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph."
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "g_Q_OMEJ4LIc"
+      },
+      "outputs": [],
+      "source": [
+        "! cat {archive_path}/resnet_v2_101_299_info.txt"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "ujCAFhqm-C6H"
+      },
+      "outputs": [],
+      "source": [
+        "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
+        "input_arrays = [\"input\"] \n",
+        "output_arrays = [\"output\"]\n",
+        "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n",
+        "  str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
+        "converter.post_training_quantize = True\n",
+        "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
+        "resnet_tflite_file.write_bytes(converter.convert())\n"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 0,
+      "metadata": {
+        "colab": {},
+        "colab_type": "code",
+        "id": "vhOjeg1x9Knp"
+      },
+      "outputs": [],
+      "source": [
+        "archive_dir = str(archive_path.parent)\n",
+        "!ls -lh {archive_dir}"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "qqHLaqFMCjRZ"
+      },
+      "source": [
+        "\n",
+        "The model size reduces from 171 MB to 43 MB.\n",
+        "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n",
+        "\n",
+        "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
+      ]
+    }
+  ],
+  "metadata": {
+    "colab": {
+      "collapsed_sections": [],
+      "name": "post-training-quant.ipynb",
+      "private_outputs": true,
+      "provenance": [],
+      "toc_visible": true,
+      "version": "0.3.2"
+    },
+    "kernelspec": {
+      "display_name": "Python 2",
+      "name": "python2"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f5b208a..6d81f84 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -22,7 +22,7 @@
 #define TENSORFLOW_CONTRIB_LITE_UTIL_H_
 
 #include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 
 namespace tflite {
 
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 32bf917..c5c1709 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -17,7 +17,7 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
 #include "tensorflow/contrib/lite/util.h"
 
 namespace tflite {
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0a54bb1..89b538d 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -44,7 +44,7 @@
 class HashTableOpTest(test.TestCase):
 
   def testHashTable(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -68,7 +68,7 @@
       self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
 
   def testHashTableFindHighRank(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -86,7 +86,7 @@
       self.assertAllEqual([[0, 1], [-1, -1]], result)
 
   def testHashTableInitWithPythonArrays(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = ["brain", "salad", "surgery"]
       values = [0, 1, 2]
@@ -105,7 +105,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testHashTableInitWithNumPyArrays(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
       values = np.array([0, 1, 2], dtype=np.int64)
@@ -122,7 +122,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testMultipleHashTables(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@
       self.assertAllEqual([0, 1, -1], out3)
 
   def testHashTableWithTensorDefault(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant(-1, dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -165,7 +165,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testHashTableWithSparseTensorInput(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_val = constant_op.constant(-1, dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -188,7 +188,7 @@
       self.assertAllEqual(sp_shape, out_shape)
 
   def testSignatureMismatch(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -210,7 +210,7 @@
             lookup.KeyValueTensorInitializer(keys, values), "UNK")
 
   def testDTypes(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       with self.assertRaises(TypeError):
         lookup.HashTable(
@@ -218,7 +218,7 @@
                                              dtypes.int64), default_val)
 
   def testNotInitialized(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       table = lookup.HashTable(
           lookup.KeyValueTensorInitializer(
@@ -232,7 +232,7 @@
         output.eval()
 
   def testInitializeTwice(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -244,7 +244,7 @@
         table.init.run()
 
   def testInitializationWithInvalidDimensions(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -283,7 +283,7 @@
       self.assertAllEqual(3, table.size().eval())
 
   def testHashTableInt32String(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = "n/a"
       keys = constant_op.constant([0, 1, 2], dtypes.int32)
       values = constant_op.constant(["brain", "salad", "surgery"])
@@ -301,7 +301,7 @@
 class MutableHashTableOpTest(test.TestCase):
 
   def testMutableHashTable(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -470,7 +470,7 @@
       self.assertAllEqual([b"-", b"a", b"b"], output.eval())
 
   def testMutableHashTableOfTensors(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant([-1, -1], dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -500,7 +500,7 @@
       self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
 
   def testMutableHashTableExportInsert(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant([-1, -1], dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -531,7 +531,7 @@
       self.assertAllEqual(expected_output, output2.eval())
 
   def testMutableHashTableOfTensorsInvalidShape(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant([-1, -1], dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
@@ -563,7 +563,7 @@
       self.assertAllEqual(3, table.size().eval())
 
   def testMutableHashTableInvalidDefaultValue(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant([[-1, -1]], dtypes.int64)
       table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
                                       default_val)
@@ -571,7 +571,7 @@
         self.assertAllEqual(0, table.size().eval())
 
   def testMutableHashTableDuplicateInsert(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
       values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@@ -589,7 +589,7 @@
       self.assertAllEqual([3, 1, -1], result)
 
   def testMutableHashTableFindHighRank(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -608,7 +608,7 @@
       self.assertAllEqual([[0, 1], [-1, -1]], result)
 
   def testMutableHashTableInsertHighRank(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
       values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
@@ -625,7 +625,7 @@
       self.assertAllEqual([0, 1, 3, -1], result)
 
   def testMutableHashTableOfTensorsFindHighRank(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
@@ -646,7 +646,7 @@
           [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
 
   def testMultipleMutableHashTables(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -676,7 +676,7 @@
       self.assertAllEqual([0, 1, -1], out3)
 
   def testMutableHashTableWithTensorDefault(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant(-1, dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -693,7 +693,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testSignatureMismatch(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -734,7 +734,7 @@
         lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
 
   def testMutableHashTableStringFloat(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1.5
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
@@ -752,7 +752,7 @@
       self.assertAllClose([0, 1.1, default_val], result)
 
   def testMutableHashTableIntFloat(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1.0
       keys = constant_op.constant([3, 7, 0], dtypes.int64)
       values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
@@ -770,7 +770,7 @@
       self.assertAllClose([-1.2, 9.9, default_val], result)
 
   def testMutableHashTableInt64String(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = "n/a"
       keys = constant_op.constant([0, 1, 2], dtypes.int64)
       values = constant_op.constant(["brain", "salad", "surgery"])
@@ -791,7 +791,7 @@
 class MutableDenseHashTableOpTest(test.TestCase):
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([0, 1, 2], dtypes.int64)
       table = lookup.MutableDenseHashTable(
@@ -809,7 +809,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testBasicBool(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([True, True, True], dtypes.bool)
       table = lookup.MutableDenseHashTable(
@@ -827,7 +827,7 @@
       self.assertAllEqual([True, True, False], result)
 
   def testLookupUnknownShape(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([0, 1, 2], dtypes.int64)
       table = lookup.MutableDenseHashTable(
@@ -843,7 +843,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testMapStringToFloat(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant(["a", "b", "c"], dtypes.string)
       values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
       default_value = constant_op.constant(-1.5, dtypes.float32)
@@ -866,7 +866,7 @@
 
   def testMapInt64ToFloat(self):
     for float_dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         keys = constant_op.constant([11, 12, 13], dtypes.int64)
         values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
         default_value = constant_op.constant(-1.5, float_dtype)
@@ -885,7 +885,7 @@
         self.assertAllClose([0, 1.1, -1.5], result)
 
   def testVectorValues(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
                                     dtypes.int64)
@@ -918,7 +918,7 @@
                           result)
 
   def testVectorKeys(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
       values = constant_op.constant([10, 11, 12], dtypes.int64)
       empty_key = constant_op.constant([0, 3], dtypes.int64)
@@ -949,7 +949,7 @@
       self.assertAllEqual([10, 11, -1], result)
 
   def testResize(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([0, 1, 2], dtypes.int64)
       table = lookup.MutableDenseHashTable(
@@ -977,7 +977,7 @@
       self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
 
   def testExport(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 12, 13], dtypes.int64)
       values = constant_op.constant([1, 2, 3], dtypes.int64)
       table = lookup.MutableDenseHashTable(
@@ -1238,7 +1238,7 @@
       self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
 
   def testReprobe(self):
-    with self.test_session():
+    with self.cached_session():
       # Insert 6 keys into a table with 8 buckets.
       # The values are chosen to make sure collisions occur when using GCC STL
       keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
@@ -1263,7 +1263,7 @@
       self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
 
   def testCustomEmptyKey(self):
-    with self.test_session():
+    with self.cached_session():
       keys = constant_op.constant([11, 0, 13], dtypes.int64)
       values = constant_op.constant([0, 1, 2], dtypes.int64)
       table = lookup.MutableDenseHashTable(
@@ -1281,7 +1281,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testErrors(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup.MutableDenseHashTable(
           dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
 
@@ -1328,7 +1328,7 @@
 
   def test_string_index_table_from_file(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1339,7 +1339,7 @@
 
   def test_string_index_table_from_file_tensor_filename(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       vocabulary_file = constant_op.constant(vocabulary_file)
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -1353,7 +1353,7 @@
 
   def test_string_index_table_from_file_placeholder_filename(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -1370,7 +1370,7 @@
   def test_int32_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab2.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1,
           key_dtype=dtypes.int32)
@@ -1384,7 +1384,7 @@
   def test_int64_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab3.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1,
           key_dtype=dtypes.int64)
@@ -1398,7 +1398,7 @@
   def test_index_table_from_file_with_default_value(self):
     default_value = -42
     vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, default_value=default_value)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1409,7 +1409,7 @@
 
   def test_index_table_from_file_with_oov_buckets(self):
     vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1000)
       ids = table.lookup(
@@ -1439,7 +1439,7 @@
 
   def test_index_table_from_file_with_vocab_size_too_small(self):
     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=2)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1451,7 +1451,7 @@
 
   def test_index_table_from_file_with_vocab_size_too_large(self):
     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=4)
       self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -1466,7 +1466,7 @@
         vocabulary_file=vocabulary_file,
         vocab_size=0)
 
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=3)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1478,7 +1478,7 @@
 
   def test_index_table_from_file_with_invalid_hashers(self):
     vocabulary_file = self._createVocabFile("invalid_hasher.txt")
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         lookup.index_table_from_file(
             vocabulary_file=vocabulary_file,
@@ -1499,21 +1499,21 @@
 class KeyValueTensorInitializerTest(test.TestCase):
 
   def test_string(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup.KeyValueTensorInitializer(
           ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
       table = lookup.HashTable(init, default_value=-1)
       table.init.run()
 
   def test_int64(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup.KeyValueTensorInitializer(
           (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
       table = lookup.HashTable(init, default_value=-1)
       table.init.run()
 
   def test_int32(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup.KeyValueTensorInitializer(
           (42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
       table = lookup.HashTable(init, default_value=-1)
@@ -1542,7 +1542,7 @@
     self.assertAllEqual((1, 2, 3), self.evaluate(ids))
 
   def test_int32_index_table_from_tensor_with_tensor_init(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_tensor(
           mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
       ids = table.lookup(
@@ -1553,7 +1553,7 @@
       self.assertAllEqual((1, 2, 3), ids.eval())
 
   def test_int64_index_table_from_tensor_with_tensor_init(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_tensor(
           mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
       ids = table.lookup(
@@ -1565,7 +1565,7 @@
 
   def test_index_table_from_tensor_with_default_value(self):
     default_value = -42
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_tensor(
           mapping=["brain", "salad", "surgery"], default_value=default_value)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1575,12 +1575,12 @@
       self.assertAllEqual((1, 2, default_value), ids.eval())
 
   def test_index_table_from_tensor_missing_mapping(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
         lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
 
   def test_index_table_from_tensor_empty_mapping(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_table_from_tensor(
           mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -1590,7 +1590,7 @@
         lookup_ops.tables_initializer().run()
 
   def test_index_table_from_tensor_with_invalid_hashers(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         lookup.index_table_from_tensor(
             mapping=["brain", "salad", "surgery"],
@@ -1609,7 +1609,7 @@
 class StringToIndexTest(test.TestCase):
 
   def test_string_to_index(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       feats = constant_op.constant(["salad", "surgery", "tarkus"])
       indices = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1620,7 +1620,7 @@
       self.assertAllEqual((1, 2, -1), indices.eval())
 
   def test_duplicate_entries(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["hello", "hello"])
       feats = constant_op.constant(["hello", "hola"])
       _ = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1630,7 +1630,7 @@
 
   def test_string_to_index_with_default_value(self):
     default_value = -42
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       feats = constant_op.constant(["salad", "surgery", "tarkus"])
       indices = lookup.string_to_index(
@@ -1651,7 +1651,7 @@
 
   def test_index_to_string_table(self):
     vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file)
       features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
@@ -1663,7 +1663,7 @@
   def test_index_to_string_table_with_default_value(self):
     default_value = b"NONE"
     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, default_value=default_value)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1675,7 +1675,7 @@
   def test_index_to_string_table_with_vocab_size_too_small(self):
     default_value = b"NONE"
     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file,
           vocab_size=2,
@@ -1688,7 +1688,7 @@
 
   def test_index_to_string_table_with_vocab_size_too_large(self):
     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=4)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1700,7 +1700,7 @@
 
   def test_index_to_string_table_with_vocab_size(self):
     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=3)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1713,7 +1713,7 @@
 class IndexToStringTableFromTensorTest(test.TestCase):
 
   def test_index_to_string_table_from_tensor(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       table = lookup.index_to_string_table_from_tensor(
           mapping=mapping_strings)
@@ -1727,7 +1727,7 @@
                           features.eval())
 
   def test_duplicate_entries(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["hello", "hello"])
       table = lookup.index_to_string_table_from_tensor(
           mapping=mapping_strings)
@@ -1738,7 +1738,7 @@
 
   def test_index_to_string_with_default_value(self):
     default_value = b"NONE"
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       table = lookup.index_to_string_table_from_tensor(
           mapping=mapping_strings, default_value=default_value)
@@ -1754,7 +1754,7 @@
 class IndexToStringTest(test.TestCase):
 
   def test_index_to_string(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
       feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1766,7 +1766,7 @@
                           feats.eval())
 
   def test_duplicate_entries(self):
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["hello", "hello"])
       indices = constant_op.constant([0, 1, 4], dtypes.int64)
       feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1778,7 +1778,7 @@
 
   def test_index_to_string_with_default_value(self):
     default_value = b"NONE"
-    with self.test_session():
+    with self.cached_session():
       mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
       indices = constant_op.constant([1, 2, 4], dtypes.int64)
       feats = lookup.index_to_string(
@@ -1818,7 +1818,7 @@
     vocabulary_file = self._createVocabFile(
         "one_column_int64.txt", values=("42", "1", "-1000"))
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       table = lookup.HashTable(
           lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
@@ -1837,7 +1837,7 @@
   def testInitializeIndexTable(self):
     vocabulary_file = self._createVocabFile("one_column_2.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       key_index = lookup.TextFileIndex.LINE_NUMBER
       value_index = lookup.TextFileIndex.WHOLE_LINE
@@ -1858,7 +1858,7 @@
     with open(vocabulary_file, "w") as f:
       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 1
       value_index = 2
@@ -1880,7 +1880,7 @@
     with open(vocabulary_file, "w") as f:
       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 2
       value_index = 1
@@ -1894,7 +1894,7 @@
   def testInvalidDataType(self):
     vocabulary_file = self._createVocabFile("one_column_3.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       key_index = lookup.TextFileIndex.WHOLE_LINE
       value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1907,7 +1907,7 @@
 
   def testInvalidIndex(self):
     vocabulary_file = self._createVocabFile("one_column_4.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 1  # second column of the line
       value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1922,7 +1922,7 @@
   def testInitializeSameTableWithMultipleNodes(self):
     vocabulary_file = self._createVocabFile("one_column_5.txt")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shared_name = "shared-one-columm"
       default_value = -1
       table1 = lookup.HashTable(
@@ -1961,7 +1961,7 @@
       self.assertAllEqual([0, 1, -1], out3)
 
   def testInitializeTableWithNoFilename(self):
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       with self.assertRaises(ValueError):
         lookup.HashTable(
@@ -1971,7 +1971,7 @@
             default_value)
 
   def testInitializeWithVocabSize(self):
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -2022,7 +2022,7 @@
   def testFeedVocabularyName(self):
     vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       table = lookup.HashTable(
           lookup.TextFileInitializer("old_file.txt", dtypes.string,
@@ -2049,7 +2049,7 @@
   def testInvalidFilenames(self):
     vocabulary_file = self._createVocabFile("filename_shape.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
 
       # Invalid data type
@@ -2072,7 +2072,7 @@
 
   def testIdToStringTable(self):
     vocab_file = self._createVocabFile("feat_to_id_1.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       vocab_size = 3
       table = lookup.HashTable(
@@ -2090,7 +2090,7 @@
 
   def testStringToIdTable(self):
     vocab_file = self._createVocabFile("feat_to_id_2.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       table = lookup.HashTable(
@@ -2108,7 +2108,7 @@
   def testInt64ToIdTable(self):
     vocab_file = self._createVocabFile(
         "feat_to_id_3.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       table = lookup.HashTable(
@@ -2133,7 +2133,7 @@
 
   def testStringIdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_1.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -2154,7 +2154,7 @@
 
   def testInt32IdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -2176,7 +2176,7 @@
 
   def testInt64IdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -2196,7 +2196,7 @@
       self.assertEquals(vocab_size + oov_buckets, table.size().eval())
 
   def testStringIdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       oov_buckets = 5
 
       # Set a table that only uses hash buckets, for each input value returns
@@ -2217,7 +2217,7 @@
       self.assertEquals(oov_buckets, table.size().eval())
 
   def testInt32IdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       oov_buckets = 5
 
       # Set a table that only uses hash buckets, for each input value returns
@@ -2239,20 +2239,20 @@
       self.assertEquals(oov_buckets, table.size().eval())
 
   def testFloat64IdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
         lookup.IdTableWithHashBuckets(
             None, num_oov_buckets=5, key_dtype=dtypes.float64)
 
   def testBoolIdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
         lookup.IdTableWithHashBuckets(
             None, num_oov_buckets=5, key_dtype=dtypes.bool)
 
   def testIdTableWithHashBucketsWithMultipleInitializers(self):
     vocab_file = self._createVocabFile("feat_to_id_4.txt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_value = -1
       vocab_size = 3
       oov_buckets = 3
@@ -2294,7 +2294,7 @@
   def testIdTableWithHashBucketsInitializationAcrossSessions(self):
     vocab_file = self._createVocabFile("feat_to_id_5.txt")
     shared_name = "across-sessions"
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -2316,7 +2316,7 @@
       self.assertAllEqual([0, 1, 2, 3], out1.eval())
       self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -2340,7 +2340,7 @@
 
   def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
     vocab_file = self._createVocabFile("feat_to_id_6.txt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_value1 = -1
       vocab_size = 3
       oov_buckets = 0
@@ -2378,7 +2378,7 @@
     vocab_file = self._createVocabFile("feat_to_id_7.txt")
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -2407,7 +2407,7 @@
   def testInt32SparseTensor(self):
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -2436,7 +2436,7 @@
   def testInt64SparseTensor(self):
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -2464,7 +2464,7 @@
 
   def testIdTableWithHashBucketsWithInvalidHashers(self):
     vocab_file = self._createVocabFile("feat_to_id_4.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 2a442a8..c0aec09 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -43,68 +43,68 @@
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.absolute_difference(
             self._predictions, self._predictions, weights=None)
 
   def testAllCorrectNoLossWeight(self):
     loss = loss_ops.absolute_difference(self._predictions, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testNonZeroLoss(self):
     loss = loss_ops.absolute_difference(self._predictions, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5, loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithScalarTensorWeight(self):
     weights = 2.3
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 0.0], shape=[2,])
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.6, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.6, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeights(self):
     weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(16.6, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
     weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(6.0, loss.eval(), 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     weights = array_ops.zeros((2, 3))
     loss = loss_ops.absolute_difference(self._predictions, self._labels,
                                         weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
@@ -117,12 +117,12 @@
     labels = constant_op.constant([[1, 0, 0],
                                    [0, 1, 0],
                                    [0, 0, 1]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.softmax_cross_entropy(logits, labels, weights=None)
 
   def testAllCorrect(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0],
                                      [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
@@ -141,7 +141,7 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
 
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels)
       self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -154,7 +154,7 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
 
@@ -166,7 +166,7 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels,
                                             constant_op.constant(weights))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -179,7 +179,7 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
     weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -191,7 +191,7 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
     weights = constant_op.constant([0, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
@@ -203,12 +203,12 @@
                                    [1, 0, 0],
                                    [0, 1, 0]])
     weights = constant_op.constant([1.2, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(12.0, loss.eval(), 3)
 
   def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -223,7 +223,7 @@
         loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval()
 
   def testSoftmaxLabelSmoothing(self):
-    with self.test_session():
+    with self.cached_session():
       # Softmax Cross Entropy Loss is:
       #   -\sum_i p_i \log q_i
       # where for a softmax activation
@@ -253,7 +253,7 @@
     weights = [2.3, 2.4, 2.5]
     weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
     loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, {weights_placeholder: weights})
       self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
 
@@ -268,7 +268,7 @@
     weights_placeholder = array_ops.placeholder(
         dtypes.float32, shape=[None, None])
     loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, {weights_placeholder: weights})
       self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
 
@@ -280,12 +280,12 @@
                                    [0.0, 10.0, 0.0],
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0], [1], [2]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None)
 
   def testAllCorrectInt32Labels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0],
                                      [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
@@ -295,7 +295,7 @@
       self.assertAlmostEqual(loss.eval(), 0.0, 3)
 
   def testAllCorrectInt64Labels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0],
                                      [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
@@ -305,7 +305,7 @@
       self.assertAlmostEqual(loss.eval(), 0.0, 3)
 
   def testAllCorrectNonColumnLabels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0],
                                      [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
@@ -320,7 +320,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
 
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -331,7 +331,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
 
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -342,7 +342,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([2, 0, 1])
 
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -353,7 +353,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
 
@@ -363,7 +363,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(
           logits, labels, constant_op.constant(weights))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -374,7 +374,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -384,7 +384,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([[1.2], [3.4], [5.6]])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -394,7 +394,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([0, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
@@ -404,12 +404,12 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([1.2, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
       self.assertAlmostEqual(12.0, loss.eval(), 3)
 
   def testMeasurementSpecificWeightsRaisesException(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -422,7 +422,7 @@
 
   def testInconsistentWeightSizeRaisesException(self):
     """The weight tensor has incorrect number of elements."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -435,7 +435,7 @@
 
   def testInconsistentLabelSizeRaisesException(self):
     """The label tensor has incorrect number of elements."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -448,7 +448,7 @@
 
   def testInconsistentWeightShapeRaisesException(self):
     """The weight tensor has incorrect shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0, -100.0],
                                      [-100.0, -100.0, 100.0, -100.0],
@@ -462,7 +462,7 @@
 
   def testInconsistentLabelShapeRaisesException(self):
     """The label tensor has incorrect shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0, -100.0],
                                      [-100.0, -100.0, 100.0, -100.0],
@@ -484,7 +484,7 @@
         dtypes.float32, shape=[None])
     loss = loss_ops.sparse_softmax_cross_entropy(
         logits, labels, weights_placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, {weights_placeholder: weights})
       self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
 
@@ -498,7 +498,7 @@
         dtypes.float32, shape=[None, None])
     loss = loss_ops.sparse_softmax_cross_entropy(
         logits, labels, weights_placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, {weights_placeholder: weights})
       self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
 
@@ -506,7 +506,7 @@
 class SigmoidCrossEntropyLossTest(test.TestCase):
 
   def testAllCorrectSigmoid(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -522,7 +522,7 @@
 
     loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: np.ones((32, 1)),
@@ -537,7 +537,7 @@
 
     loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: np.ones((32, 2)),
@@ -546,7 +546,7 @@
       self.assertAlmostEqual(0.313, loss, 3)
 
   def testAllWrongSigmoid(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -558,7 +558,7 @@
       self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
 
   def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -582,11 +582,11 @@
     loss = loss_ops.sigmoid_cross_entropy(logits, labels)
     self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(loss.eval(), 0.0, 3)
 
   def testSigmoidLabelSmoothingCorrect(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0]])
       labels = constant_op.constant([[1, 0, 1]])
       # Sigmoid cross entropy loss is:
@@ -608,7 +608,7 @@
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
 
   def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
-    with self.test_session():
+    with self.cached_session():
       label_smoothing = 0.1
       sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
       sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -641,33 +641,33 @@
     self._labels = constant_op.constant(labels)
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.log_loss(self._labels, self._labels, weights=None)
 
   def testAllCorrectNoLossWeight(self):
     loss = loss_ops.log_loss(self._labels, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testAllCorrectNoLossWeightWithPlaceholder(self):
     tf_predictions = array_ops.placeholder(
         dtypes.float32, shape=self._np_labels.shape)
     loss = loss_ops.log_loss(tf_predictions, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(
           0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
 
   def testNonZeroLoss(self):
     loss = loss_ops.log_loss(self._predictions, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = loss_ops.log_loss(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
@@ -675,7 +675,7 @@
     weights = 2.3
     loss = loss_ops.log_loss(self._predictions, self._labels,
                              constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
@@ -685,7 +685,7 @@
     weights = 2.3
     loss = loss_ops.log_loss(tf_predictions, self._labels,
                              constant_op.constant(weights))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss, 3)
@@ -695,7 +695,7 @@
     weights = 2.3
     loss = loss_ops.log_loss(tf_predictions, self._labels,
                              constant_op.constant(weights))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss, 3)
@@ -706,7 +706,7 @@
         self._expected_losses,
         np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
     loss = loss_ops.log_loss(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -715,7 +715,7 @@
                                   np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                       (2, 3)))
     loss = loss_ops.log_loss(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -724,12 +724,12 @@
                                   np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                       (2, 3)))
     loss = loss_ops.log_loss(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
 
   def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
     weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.log_loss(self._predictions, self._labels, weights)
 
@@ -742,7 +742,7 @@
         self._labels,
         constant_op.constant(
             weights, shape=(2, 3)))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
 
   def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -756,7 +756,7 @@
         constant_op.constant(
             weights, shape=(2, 3)))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
 
@@ -769,7 +769,7 @@
         self._labels,
         constant_op.constant(
             weights, shape=(2, 3)))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -780,35 +780,35 @@
     tf_weights = constant_op.constant(weights, shape=(2, 3))
     loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     tf_weights = array_ops.zeros(shape=(2, 3))
     loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
 class HingeLossTest(test.TestCase):
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[-1.0], [2.1]])
       labels = constant_op.constant([0.0, 1.0])
       with self.assertRaises(ValueError):
         _ = loss_ops.hinge_loss(logits, labels).eval()
 
   def testAllOutsideMargin(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
       labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
       loss = loss_ops.hinge_loss(logits, labels)
       self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
 
   def testSomeInsideMargin(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
       labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
       loss = loss_ops.hinge_loss(logits, labels)
@@ -817,7 +817,7 @@
       self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
 
   def testSomeMisclassified(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
       labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
       loss = loss_ops.hinge_loss(logits, labels)
@@ -834,62 +834,62 @@
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.mean_squared_error(
             self._predictions, self._predictions, weights=None)
 
   def testAllCorrectNoLossWeight(self):
     loss = loss_ops.mean_squared_error(self._predictions, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testNonZeroLoss(self):
     loss = loss_ops.mean_squared_error(self._predictions, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5, loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithScalarTensorWeight(self):
     weights = 2.3
     loss = loss_ops.mean_squared_error(self._predictions, self._labels,
                                        constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 3.4], shape=[2,])
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeights(self):
     weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
     weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(18.0, loss.eval(), 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     weights = array_ops.zeros((2, 3))
     loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
@@ -914,7 +914,7 @@
     self._expected_losses = np.divide(total, 9.0)
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.mean_pairwise_squared_error(
             predictions=constant_op.constant(self._labels),
@@ -925,14 +925,14 @@
     loss = loss_ops.mean_pairwise_squared_error(
         predictions=constant_op.constant(self._labels),
         labels=constant_op.constant(self._labels))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testNonZeroLoss(self):
     loss = loss_ops.mean_pairwise_squared_error(
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
 
   def testGradientWithZeroWeight(self):
@@ -954,7 +954,7 @@
 
       init_op = variables.global_variables_initializer()
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(init_op)
         for grad, _ in gradients_to_variables:
           np_grad = sess.run(grad)
@@ -966,7 +966,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         weights=weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * np.sum(self._expected_losses),
                              loss.eval(), 3)
 
@@ -976,7 +976,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * np.sum(self._expected_losses),
                              loss.eval(), 3)
 
@@ -986,7 +986,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0, loss.eval(), 3)
 
   def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
@@ -998,7 +998,7 @@
         predictions=tf_predictions,
         labels=tf_labels,
         weights=constant_op.constant(weights))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           tf_predictions: self._predictions,
@@ -1015,7 +1015,7 @@
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(
             weights, shape=[2]))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
 
   def testZeroLossWithOneDimBatchZeroWeights(self):
@@ -1025,7 +1025,7 @@
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(
             weights, shape=[2]))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
@@ -1041,7 +1041,7 @@
         weights=constant_op.constant(
             weights, shape=[2]))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           tf_predictions: self._predictions,
@@ -1056,7 +1056,7 @@
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(
             weights, shape=[2]))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testLossIsAssociativeAcrossBatchElements(self):
@@ -1087,7 +1087,7 @@
           predictions=array_ops.concat([predictions0, predictions1], 0),
           labels=array_ops.concat([labels0, labels1], 0))
 
-      with self.test_session() as session:
+      with self.cached_session() as session:
         loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
 
         self.assertTrue(loss0 > 0)
@@ -1115,7 +1115,7 @@
                                [0, 1, 0]]).reshape((3, 2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.cosine_distance(
             predictions=constant_op.constant(self._labels),
@@ -1128,7 +1128,7 @@
         predictions=constant_op.constant(self._labels),
         labels=constant_op.constant(self._labels),
         dim=2)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0, loss.eval(), 5)
 
   def testPartiallyCorrectWithIntegerValues(self):
@@ -1136,7 +1136,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         dim=2)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(1, loss.eval(), 5)
 
   def testPartiallyCorrectFloatingPointValues(self):
@@ -1154,7 +1154,7 @@
         labels, shape=(3, 1, 3), dtype=dtypes.float32)
     loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(1.0, loss.eval(), 5)
 
   def testSampleSpecificWeights(self):
@@ -1163,7 +1163,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=constant_op.constant([1, 0, 0]))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(1.0, loss.eval())
 
   def testMeasurementSpecificWeights(self):
@@ -1173,12 +1173,12 @@
         dim=2,
         weights=constant_op.constant(
             [1, 0, 0, 1, 1, 1], shape=(3, 2)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(3.0 / 4.0, loss.eval())
 
   def testValueErrorThrownWithShapelessPlaceholder(self):
     tf_predictions = array_ops.placeholder(dtypes.float32)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         loss_ops.cosine_distance(
             predictions=tf_predictions,
@@ -1196,7 +1196,7 @@
         dim=2,
         weights=constant_op.constant(
             [1, 0, 0, 1, 1, 1], shape=(3, 2)))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
       self.assertEqual(3.0 / 4.0, loss)
 
@@ -1206,7 +1206,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=array_ops.zeros((3,)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0, loss.eval())
 
   def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1215,7 +1215,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=array_ops.zeros((3, 2)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0, loss.eval())
 
 
@@ -1228,7 +1228,7 @@
     self.assertFalse(loss_ops.get_losses())
     loss = loss_ops.compute_weighted_loss(losses)
     self.assertTrue(loss_ops.get_losses())
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
       self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3)
 
@@ -1243,7 +1243,7 @@
     loss_ops.add_loss(math_ops.reduce_mean(losses))
     self.assertTrue(loss_ops.get_losses())
     total_loss = loss_ops.get_total_loss()
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
       self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3)
 
@@ -1254,7 +1254,7 @@
     self.assertFalse(loss_ops.get_losses())
     loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None)
     self.assertFalse(loss_ops.get_losses())
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
 
   def testNoCollectLosses(self):
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 22b11f1..9ea94c7 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -1,61 +1,61 @@
-tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
-tensorflow/tools/proto_text/gen_proto_text_functions.cc
 tensorflow/core/framework/resource_handle.cc
-tensorflow/core/platform/default/protobuf.cc
-tensorflow/core/platform/tracing.cc
-tensorflow/core/platform/tensor_coding.cc
-tensorflow/core/platform/protobuf_util.cc
-tensorflow/core/platform/posix/posix_file_system.cc
-tensorflow/core/platform/posix/port.cc
-tensorflow/core/platform/posix/error.cc
-tensorflow/core/platform/posix/env.cc
-tensorflow/core/platform/posix/load_library.cc
-tensorflow/core/platform/posix/env_time.cc
-tensorflow/core/platform/file_system.cc
-tensorflow/core/platform/file_system_helper.cc
-tensorflow/core/platform/env.cc
-tensorflow/core/platform/env_time.cc
-tensorflow/core/platform/setround.cc
-tensorflow/core/platform/denormal.cc
-tensorflow/core/platform/default/tracing.cc
-tensorflow/core/platform/default/mutex.cc
-tensorflow/core/platform/default/logging.cc
-tensorflow/core/platform/cpu_info.cc
-tensorflow/core/lib/wav/wav_io.cc
-tensorflow/core/lib/strings/stringprintf.cc
-tensorflow/core/lib/strings/strcat.cc
-tensorflow/core/lib/strings/str_util.cc
-tensorflow/core/lib/strings/scanner.cc
-tensorflow/core/lib/strings/proto_text_util.cc
-tensorflow/core/lib/strings/ordered_code.cc
-tensorflow/core/lib/strings/numbers.cc
-tensorflow/core/lib/random/weighted_picker.cc
-tensorflow/core/lib/random/simple_philox.cc
-tensorflow/core/lib/random/random.cc
-tensorflow/core/lib/random/distribution_sampler.cc
-tensorflow/core/lib/io/zlib_outputbuffer.cc
-tensorflow/core/lib/io/zlib_inputstream.cc
-tensorflow/core/lib/io/zlib_compression_options.cc
-tensorflow/core/lib/io/two_level_iterator.cc
-tensorflow/core/lib/io/table_builder.cc
-tensorflow/core/lib/io/table.cc
-tensorflow/core/lib/io/record_writer.cc
-tensorflow/core/lib/io/record_reader.cc
-tensorflow/core/lib/io/random_inputstream.cc
-tensorflow/core/lib/io/path.cc
-tensorflow/core/lib/io/iterator.cc
-tensorflow/core/lib/io/inputstream_interface.cc
-tensorflow/core/lib/io/inputbuffer.cc
-tensorflow/core/lib/io/format.cc
-tensorflow/core/lib/io/compression.cc
-tensorflow/core/lib/io/buffered_inputstream.cc
-tensorflow/core/lib/io/block_builder.cc
-tensorflow/core/lib/io/block.cc
-tensorflow/core/lib/histogram/histogram.cc
-tensorflow/core/lib/hash/hash.cc
+tensorflow/core/lib/core/arena.cc
+tensorflow/core/lib/core/coding.cc
+tensorflow/core/lib/core/status.cc
+tensorflow/core/lib/core/threadpool.cc
 tensorflow/core/lib/hash/crc32c.cc
 tensorflow/core/lib/hash/crc32c_accelerate.cc
-tensorflow/core/lib/core/threadpool.cc
-tensorflow/core/lib/core/status.cc
-tensorflow/core/lib/core/coding.cc
-tensorflow/core/lib/core/arena.cc
+tensorflow/core/lib/hash/hash.cc
+tensorflow/core/lib/histogram/histogram.cc
+tensorflow/core/lib/io/block.cc
+tensorflow/core/lib/io/block_builder.cc
+tensorflow/core/lib/io/buffered_inputstream.cc
+tensorflow/core/lib/io/compression.cc
+tensorflow/core/lib/io/format.cc
+tensorflow/core/lib/io/inputbuffer.cc
+tensorflow/core/lib/io/inputstream_interface.cc
+tensorflow/core/lib/io/iterator.cc
+tensorflow/core/lib/io/path.cc
+tensorflow/core/lib/io/random_inputstream.cc
+tensorflow/core/lib/io/record_reader.cc
+tensorflow/core/lib/io/record_writer.cc
+tensorflow/core/lib/io/table.cc
+tensorflow/core/lib/io/table_builder.cc
+tensorflow/core/lib/io/two_level_iterator.cc
+tensorflow/core/lib/io/zlib_compression_options.cc
+tensorflow/core/lib/io/zlib_inputstream.cc
+tensorflow/core/lib/io/zlib_outputbuffer.cc
+tensorflow/core/lib/random/distribution_sampler.cc
+tensorflow/core/lib/random/random.cc
+tensorflow/core/lib/random/simple_philox.cc
+tensorflow/core/lib/random/weighted_picker.cc
+tensorflow/core/lib/strings/numbers.cc
+tensorflow/core/lib/strings/ordered_code.cc
+tensorflow/core/lib/strings/proto_text_util.cc
+tensorflow/core/lib/strings/scanner.cc
+tensorflow/core/lib/strings/str_util.cc
+tensorflow/core/lib/strings/strcat.cc
+tensorflow/core/lib/strings/stringprintf.cc
+tensorflow/core/lib/wav/wav_io.cc
+tensorflow/core/platform/cpu_info.cc
+tensorflow/core/platform/default/logging.cc
+tensorflow/core/platform/default/mutex.cc
+tensorflow/core/platform/default/protobuf.cc
+tensorflow/core/platform/default/tracing.cc
+tensorflow/core/platform/denormal.cc
+tensorflow/core/platform/env.cc
+tensorflow/core/platform/env_time.cc
+tensorflow/core/platform/file_system.cc
+tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/posix/env.cc
+tensorflow/core/platform/posix/env_time.cc
+tensorflow/core/platform/posix/error.cc
+tensorflow/core/platform/posix/load_library.cc
+tensorflow/core/platform/posix/port.cc
+tensorflow/core/platform/posix/posix_file_system.cc
+tensorflow/core/platform/protobuf_util.cc
+tensorflow/core/platform/setround.cc
+tensorflow/core/platform/tensor_coding.cc
+tensorflow/core/platform/tracing.cc
+tensorflow/tools/proto_text/gen_proto_text_functions.cc
+tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 938c4a5..1d6d9a6 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -1,41 +1,42 @@
-tensorflow/core/util/test_log.pb.cc
-tensorflow/core/util/saved_tensor_slice.pb.cc
-tensorflow/core/util/memmapped_file_system.pb.cc
-tensorflow/core/util/event.pb.cc
-tensorflow/core/protobuf/tensorflow_server.pb.cc
-tensorflow/core/protobuf/saver.pb.cc
-tensorflow/core/protobuf/queue_runner.pb.cc
-tensorflow/core/protobuf/named_tensor.pb.cc
-tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/example/example.pb.cc
+tensorflow/core/example/feature.pb.cc
+tensorflow/core/framework/allocation_description.pb.cc
+tensorflow/core/framework/api_def.pb.cc
+tensorflow/core/framework/attr_value.pb.cc
+tensorflow/core/framework/cost_graph.pb.cc
+tensorflow/core/framework/device_attributes.pb.cc
+tensorflow/core/framework/function.pb.cc
+tensorflow/core/framework/graph.pb.cc
+tensorflow/core/framework/graph_transfer_info.pb.cc
+tensorflow/core/framework/kernel_def.pb.cc
+tensorflow/core/framework/log_memory.pb.cc
+tensorflow/core/framework/model.pb.cc
+tensorflow/core/framework/node_def.pb.cc
+tensorflow/core/framework/op_def.pb.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
+tensorflow/core/framework/resource_handle.pb.cc
+tensorflow/core/framework/step_stats.pb.cc
+tensorflow/core/framework/summary.pb.cc
+tensorflow/core/framework/tensor.pb.cc
+tensorflow/core/framework/tensor_description.pb.cc
+tensorflow/core/framework/tensor_shape.pb.cc
+tensorflow/core/framework/tensor_slice.pb.cc
+tensorflow/core/framework/types.pb.cc
+tensorflow/core/framework/variable.pb.cc
+tensorflow/core/framework/versions.pb.cc
+tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/lib/core/error_codes.pb.cc
 tensorflow/core/protobuf/cluster.pb.cc
 tensorflow/core/protobuf/config.pb.cc
-tensorflow/core/protobuf/rewriter_config.pb.cc
 tensorflow/core/protobuf/debug.pb.cc
 tensorflow/core/protobuf/device_properties.pb.cc
-tensorflow/core/lib/core/error_codes.pb.cc
-tensorflow/core/framework/versions.pb.cc
-tensorflow/core/framework/variable.pb.cc
-tensorflow/core/framework/types.pb.cc
-tensorflow/core/framework/tensor_slice.pb.cc
-tensorflow/core/framework/tensor_shape.pb.cc
-tensorflow/core/framework/tensor_description.pb.cc
-tensorflow/core/framework/tensor.pb.cc
-tensorflow/core/framework/summary.pb.cc
-tensorflow/core/framework/step_stats.pb.cc
-tensorflow/core/framework/resource_handle.pb.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
-tensorflow/core/framework/api_def.pb.cc
-tensorflow/core/framework/op_def.pb.cc
-tensorflow/core/framework/node_def.pb.cc
-tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/kernel_def.pb.cc
-tensorflow/core/framework/graph_transfer_info.pb.cc
-tensorflow/core/framework/graph.pb.cc
-tensorflow/core/framework/function.pb.cc
-tensorflow/core/framework/device_attributes.pb.cc
-tensorflow/core/framework/cost_graph.pb.cc
-tensorflow/core/framework/attr_value.pb.cc
-tensorflow/core/framework/allocation_description.pb.cc
-tensorflow/core/example/feature.pb.cc
-tensorflow/core/example/example.pb.cc
-tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/named_tensor.pb.cc
+tensorflow/core/protobuf/queue_runner.pb.cc
+tensorflow/core/protobuf/rewriter_config.pb.cc
+tensorflow/core/protobuf/saver.pb.cc
+tensorflow/core/protobuf/tensorflow_server.pb.cc
+tensorflow/core/util/event.pb.cc
+tensorflow/core/util/memmapped_file_system.pb.cc
+tensorflow/core/util/saved_tensor_slice.pb.cc
+tensorflow/core/util/test_log.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index aa91b2f..884461e 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -1,42 +1,44 @@
-tensorflow/core/util/test_log.pb.h
-tensorflow/core/util/saved_tensor_slice.pb.h
-tensorflow/core/util/memmapped_file_system.pb.h
-tensorflow/core/util/event.pb.h
-tensorflow/core/protobuf/tensorflow_server.pb.h
-tensorflow/core/protobuf/saver.pb.h
-tensorflow/core/protobuf/queue_runner.pb.h
-tensorflow/core/protobuf/named_tensor.pb.h
-tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/example/example.pb.h
+tensorflow/core/example/feature.pb.h
+tensorflow/core/framework/allocation_description.pb.h
+tensorflow/core/framework/api_def.pb.h
+tensorflow/core/framework/attr_value.pb.h
+tensorflow/core/framework/cost_graph.pb.h
+tensorflow/core/framework/device_attributes.pb.h
+tensorflow/core/framework/function.pb.h
+tensorflow/core/framework/graph.pb.h
+tensorflow/core/framework/graph_transfer_info.pb.h
+tensorflow/core/framework/kernel_def.pb.h
+tensorflow/core/framework/log_memory.pb.h
+tensorflow/core/framework/model.pb.h
+tensorflow/core/framework/node_def.pb.h
+tensorflow/core/framework/op_def.pb.h
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
+tensorflow/core/framework/resource_handle.pb.h
+tensorflow/core/framework/step_stats.pb.h
+tensorflow/core/framework/summary.pb.h
+tensorflow/core/framework/tensor.pb.h
+tensorflow/core/framework/tensor_description.pb.h
+tensorflow/core/framework/tensor_shape.pb.h
+tensorflow/core/framework/tensor_slice.pb.h
+tensorflow/core/framework/types.pb.h
+tensorflow/core/framework/variable.pb.h
+tensorflow/core/framework/versions.pb.h
+tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/lib/core/error_codes.pb.h
 tensorflow/core/protobuf/cluster.pb.h
 tensorflow/core/protobuf/config.pb.h
 tensorflow/core/protobuf/debug.pb.h
 tensorflow/core/protobuf/device_properties.pb.h
+tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/named_tensor.pb.h
+tensorflow/core/protobuf/queue_runner.pb.h
 tensorflow/core/protobuf/rewriter_config.pb.h
+tensorflow/core/protobuf/saver.pb.h
 tensorflow/core/protobuf/tensor_bundle.pb.h
-tensorflow/core/lib/core/error_codes.pb.h
-tensorflow/core/framework/versions.pb.h
-tensorflow/core/framework/variable.pb.h
-tensorflow/core/framework/types.pb.h
-tensorflow/core/framework/tensor_slice.pb.h
-tensorflow/core/framework/tensor_shape.pb.h
-tensorflow/core/framework/tensor_description.pb.h
-tensorflow/core/framework/tensor.pb.h
-tensorflow/core/framework/summary.pb.h
-tensorflow/core/framework/step_stats.pb.h
-tensorflow/core/framework/resource_handle.pb.h
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
-tensorflow/core/framework/api_def.pb.h
-tensorflow/core/framework/op_def.pb.h
-tensorflow/core/framework/node_def.pb.h
-tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/kernel_def.pb.h
-tensorflow/core/framework/graph_transfer_info.pb.h
-tensorflow/core/framework/graph.pb.h
-tensorflow/core/framework/function.pb.h
-tensorflow/core/framework/device_attributes.pb.h
-tensorflow/core/framework/cost_graph.pb.h
-tensorflow/core/framework/attr_value.pb.h
-tensorflow/core/framework/allocation_description.pb.h
-tensorflow/core/example/feature.pb.h
-tensorflow/core/example/example.pb.h
-tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/protobuf/tensorflow_server.pb.h
+tensorflow/core/util/event.pb.h
+tensorflow/core/util/memmapped_file_system.pb.h
+tensorflow/core/util/saved_tensor_slice.pb.h
+tensorflow/core/util/test_log.pb.h
+
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 66a3315..08de54b 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -4,218 +4,19 @@
 tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
 tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
 tensorflow/contrib/boosted_trees/ops/training_ops.cc
-tensorflow/core/kernels/xent_op.cc
-tensorflow/core/kernels/where_op.cc
-tensorflow/core/kernels/variable_ops.cc
-tensorflow/core/kernels/unpack_op.cc
-tensorflow/core/kernels/unique_op.cc
-tensorflow/core/kernels/transpose_op.cc
-tensorflow/core/kernels/transpose_functor_cpu.cc
-tensorflow/core/kernels/training_op_helpers.cc
-tensorflow/core/kernels/training_ops.cc
-tensorflow/core/kernels/topk_op.cc
-tensorflow/core/kernels/tile_functor_cpu.cc
-tensorflow/core/kernels/tile_ops.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
-tensorflow/core/kernels/tensor_array_ops.cc
-tensorflow/core/kernels/tensor_array.cc
-tensorflow/core/kernels/strided_slice_op_inst_7.cc
-tensorflow/core/kernels/strided_slice_op_inst_6.cc
-tensorflow/core/kernels/strided_slice_op_inst_5.cc
-tensorflow/core/kernels/strided_slice_op_inst_4.cc
-tensorflow/core/kernels/strided_slice_op_inst_3.cc
-tensorflow/core/kernels/strided_slice_op_inst_2.cc
-tensorflow/core/kernels/strided_slice_op_inst_1.cc
-tensorflow/core/kernels/strided_slice_op_inst_0.cc
-tensorflow/core/kernels/strided_slice_op.cc
-tensorflow/core/kernels/stack_ops.cc
-tensorflow/core/kernels/split_op.cc
-tensorflow/core/kernels/split_v_op.cc
-tensorflow/core/kernels/split_lib_cpu.cc
-tensorflow/core/kernels/spectrogram_op.cc
-tensorflow/core/kernels/spectrogram.cc
-tensorflow/core/kernels/sparse_to_dense_op.cc
-tensorflow/core/kernels/sparse_matmul_op.cc
-tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
-tensorflow/core/kernels/sparse_reshape_op.c
-tensorflow/core/kernels/segment_reduction_ops.cc
-tensorflow/core/kernels/softsign_op.cc
-tensorflow/core/kernels/softplus_op.cc
-tensorflow/core/kernels/softmax_op.cc
-tensorflow/core/kernels/slice_op_cpu_impl_1.cc
-tensorflow/core/kernels/slice_op_cpu_impl_2.cc
-tensorflow/core/kernels/slice_op_cpu_impl_3.cc
-tensorflow/core/kernels/slice_op_cpu_impl_4.cc
-tensorflow/core/kernels/slice_op_cpu_impl_5.cc
-tensorflow/core/kernels/slice_op_cpu_impl_6.cc
-tensorflow/core/kernels/slice_op_cpu_impl_7.cc
-tensorflow/core/kernels/slice_op.cc
-tensorflow/core/kernels/shape_ops.cc
-tensorflow/core/kernels/session_ops.cc
-tensorflow/core/kernels/sequence_ops.cc
-tensorflow/core/kernels/sendrecv_ops.cc
-tensorflow/core/kernels/scatter_op.cc
-tensorflow/core/kernels/scatter_functor.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/scatter_nd_op.cc
-tensorflow/core/kernels/save_restore_tensor.cc
-tensorflow/core/kernels/save_restore_v2_ops.cc
-tensorflow/core/kernels/save_op.cc
-tensorflow/core/kernels/string_join_op.cc
-tensorflow/core/kernels/reverse_sequence_op.cc
-tensorflow/core/kernels/reverse_op.cc
-tensorflow/core/kernels/restore_op.cc
-tensorflow/core/kernels/resize_nearest_neighbor_op.cc
-tensorflow/core/kernels/resize_bilinear_op.cc
-tensorflow/core/kernels/reshape_util.cc
-tensorflow/core/kernels/reshape_op.cc
-tensorflow/core/kernels/relu_op.cc
-tensorflow/core/kernels/reduction_ops_sum.cc
-tensorflow/core/kernels/reduction_ops_prod.cc
-tensorflow/core/kernels/reduction_ops_min.cc
-tensorflow/core/kernels/reduction_ops_mean.cc
-tensorflow/core/kernels/reduction_ops_max.cc
-tensorflow/core/kernels/reduction_ops_common.cc
-tensorflow/core/kernels/reduction_ops_any.cc
-tensorflow/core/kernels/reduction_ops_all.cc
-tensorflow/core/kernels/roll_op.cc
-tensorflow/core/kernels/queue_op.cc
-tensorflow/core/kernels/queue_ops.cc
-tensorflow/core/kernels/queue_base.cc
-tensorflow/core/kernels/pooling_ops_common.cc
-tensorflow/core/kernels/padding_fifo_queue_op.cc
-tensorflow/core/kernels/padding_fifo_queue.cc
-tensorflow/core/kernels/pad_op.cc
-tensorflow/core/kernels/pack_op.cc
-tensorflow/core/kernels/ops_util.cc
-tensorflow/core/kernels/one_hot_op.cc
-tensorflow/core/kernels/non_max_suppression_op.cc
-tensorflow/core/kernels/no_op.cc
-tensorflow/core/kernels/mirror_pad_op.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
-tensorflow/core/kernels/mfcc_op.cc
-tensorflow/core/kernels/mfcc_mel_filterbank.cc
-tensorflow/core/kernels/mfcc_dct.cc
-tensorflow/core/kernels/mfcc.cc
-tensorflow/core/kernels/maxpooling_op.cc
-tensorflow/core/kernels/matmul_op.cc
-tensorflow/core/kernels/lrn_op.cc
-tensorflow/core/kernels/logging_ops.cc
-tensorflow/core/kernels/initializable_lookup_table.c
-tensorflow/core/kernels/lookup_table_init_op.cc
-tensorflow/core/kernels/lookup_table_op.cc
-tensorflow/core/kernels/lookup_util.cc
-tensorflow/core/kernels/inplace_ops.cc
-tensorflow/core/kernels/in_topk_op.cc
-tensorflow/core/kernels/immutable_constant_op.cc
-tensorflow/core/kernels/identity_op.cc
-tensorflow/core/kernels/identity_n_op.cc
-tensorflow/core/kernels/gather_op.cc
-tensorflow/core/kernels/gather_functor.cc
-tensorflow/core/kernels/gather_nd_op.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/fused_batch_norm_op.cc
-tensorflow/core/kernels/function_ops.cc
-tensorflow/core/kernels/fill_functor.cc
-tensorflow/core/kernels/fifo_queue.cc
-tensorflow/core/kernels/fifo_queue_op.cc
-tensorflow/core/kernels/fake_quant_ops.cc
-tensorflow/core/kernels/example_parsing_ops.cc
-tensorflow/core/kernels/encode_wav_op.cc
-tensorflow/core/kernels/dynamic_stitch_op.cc
-tensorflow/core/kernels/dynamic_partition_op.cc
-tensorflow/core/kernels/decode_bmp_op.cc
-tensorflow/core/kernels/depthtospace_op.cc
-tensorflow/core/kernels/data_format_ops.cc
-tensorflow/core/kernels/spacetodepth_op.cc
-tensorflow/core/kernels/dense_update_functor.cc
-tensorflow/core/kernels/dense_update_ops.cc
-tensorflow/core/kernels/deep_conv2d.cc
-tensorflow/core/kernels/decode_wav_op.cc
-tensorflow/core/kernels/xsmm_conv2d.cc
-tensorflow/core/kernels/cwise_ops_common.cc
-tensorflow/core/kernels/cwise_op_tanh.cc
-tensorflow/core/kernels/cwise_op_pow.cc
-tensorflow/core/kernels/cwise_op_sub.cc
-tensorflow/core/kernels/cwise_op_squared_difference.cc
-tensorflow/core/kernels/cwise_op_square.cc
-tensorflow/core/kernels/cwise_op_sqrt.cc
-tensorflow/core/kernels/cwise_op_sigmoid.cc
-tensorflow/core/kernels/cwise_op_sign.cc
-tensorflow/core/kernels/cwise_op_select.cc
-tensorflow/core/kernels/cwise_op_round.cc
-tensorflow/core/kernels/cwise_op_rsqrt.cc
-tensorflow/core/kernels/cwise_op_reciprocal.cc
-tensorflow/core/kernels/cwise_op_neg.cc
-tensorflow/core/kernels/cwise_op_mul_2.cc
-tensorflow/core/kernels/cwise_op_mul_1.cc
-tensorflow/core/kernels/cwise_op_minimum.cc
-tensorflow/core/kernels/cwise_op_maximum.cc
-tensorflow/core/kernels/cwise_op_logical_not.cc
-tensorflow/core/kernels/cwise_op_logical_and.cc
-tensorflow/core/kernels/cwise_op_logical_or.cc
-tensorflow/core/kernels/cwise_op_log.cc
-tensorflow/core/kernels/cwise_op_less.cc
-tensorflow/core/kernels/cwise_op_less_equal.cc
-tensorflow/core/kernels/cwise_op_isnan.cc
-tensorflow/core/kernels/cwise_op_isfinite.cc
-tensorflow/core/kernels/cwise_op_invert.cc
-tensorflow/core/kernels/cwise_op_greater_equal.cc
-tensorflow/core/kernels/cwise_op_greater.cc
-tensorflow/core/kernels/cwise_op_floor_div.cc
-tensorflow/core/kernels/cwise_op_floor_mod.cc
-tensorflow/core/kernels/cwise_op_floor.cc
-tensorflow/core/kernels/cwise_op_exp.cc
-tensorflow/core/kernels/cwise_op_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_div.cc
-tensorflow/core/kernels/cwise_op_bitwise_xor.cc
-tensorflow/core/kernels/cwise_op_bitwise_or.cc
-tensorflow/core/kernels/cwise_op_bitwise_and.cc
-tensorflow/core/kernels/cwise_op_left_shift.cc
-tensorflow/core/kernels/cwise_op_right_shift.cc
-tensorflow/core/kernels/cwise_op_add_2.cc
-tensorflow/core/kernels/cwise_op_add_1.cc
-tensorflow/core/kernels/cwise_op_abs.cc
-tensorflow/core/kernels/ctc_decoder_ops.cc
-tensorflow/core/kernels/crop_and_resize_op.cc
-tensorflow/core/kernels/conv_ops_using_gemm.cc
-tensorflow/core/kernels/conv_ops_fused.cc
-tensorflow/core/kernels/conv_ops.cc
-tensorflow/core/kernels/conv_grad_filter_ops.cc
-tensorflow/core/kernels/conv_grad_input_ops.cc
-tensorflow/core/kernels/conv_grad_ops.cc
-tensorflow/core/kernels/control_flow_ops.cc
-tensorflow/core/kernels/constant_op.cc
-tensorflow/core/kernels/concat_op.cc
-tensorflow/core/kernels/concat_lib_cpu.cc
-tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/argmax_op.cc
+tensorflow/core/kernels/avgpooling_op.cc
+tensorflow/core/kernels/batch_matmul_op_real.cc
+tensorflow/core/kernels/batch_norm_op.cc
+tensorflow/core/kernels/batchtospace_op.cc
+tensorflow/core/kernels/bcast_ops.cc
+tensorflow/core/kernels/bias_op.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
 tensorflow/core/kernels/cast_op.cc
 tensorflow/core/kernels/cast_op_impl_bfloat.cc
 tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -232,20 +33,131 @@
 tensorflow/core/kernels/cast_op_impl_uint32.cc
 tensorflow/core/kernels/cast_op_impl_uint64.cc
 tensorflow/core/kernels/cast_op_impl_uint8.cc
-tensorflow/core/kernels/boosted_trees/prediction_ops.cc
-tensorflow/core/kernels/boosted_trees/resource_ops.cc
-tensorflow/core/kernels/boosted_trees/resources.cc
-tensorflow/core/kernels/boosted_trees/stats_ops.cc
-tensorflow/core/kernels/boosted_trees/training_ops.cc
-tensorflow/core/kernels/bias_op.cc
-tensorflow/core/kernels/bcast_ops.cc
-tensorflow/core/kernels/batch_norm_op.cc
-tensorflow/core/kernels/avgpooling_op.cc
-tensorflow/core/kernels/argmax_op.cc
-tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/concat_lib_cpu.cc
+tensorflow/core/kernels/concat_op.cc
+tensorflow/core/kernels/constant_op.cc
+tensorflow/core/kernels/control_flow_ops.cc
+tensorflow/core/kernels/conv_grad_filter_ops.cc
+tensorflow/core/kernels/conv_grad_input_ops.cc
+tensorflow/core/kernels/conv_grad_ops.cc
+tensorflow/core/kernels/conv_ops.cc
+tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/crop_and_resize_op.cc
+tensorflow/core/kernels/ctc_decoder_ops.cc
+tensorflow/core/kernels/cwise_op_abs.cc
+tensorflow/core/kernels/cwise_op_add_1.cc
+tensorflow/core/kernels/cwise_op_add_2.cc
+tensorflow/core/kernels/cwise_op_bitwise_and.cc
+tensorflow/core/kernels/cwise_op_bitwise_or.cc
+tensorflow/core/kernels/cwise_op_bitwise_xor.cc
+tensorflow/core/kernels/cwise_op_div.cc
+tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_exp.cc
+tensorflow/core/kernels/cwise_op_floor.cc
+tensorflow/core/kernels/cwise_op_floor_div.cc
+tensorflow/core/kernels/cwise_op_floor_mod.cc
+tensorflow/core/kernels/cwise_op_greater.cc
+tensorflow/core/kernels/cwise_op_greater_equal.cc
+tensorflow/core/kernels/cwise_op_invert.cc
+tensorflow/core/kernels/cwise_op_isfinite.cc
+tensorflow/core/kernels/cwise_op_isnan.cc
+tensorflow/core/kernels/cwise_op_left_shift.cc
+tensorflow/core/kernels/cwise_op_less.cc
+tensorflow/core/kernels/cwise_op_less_equal.cc
+tensorflow/core/kernels/cwise_op_log.cc
+tensorflow/core/kernels/cwise_op_logical_and.cc
+tensorflow/core/kernels/cwise_op_logical_not.cc
+tensorflow/core/kernels/cwise_op_logical_or.cc
+tensorflow/core/kernels/cwise_op_maximum.cc
+tensorflow/core/kernels/cwise_op_minimum.cc
+tensorflow/core/kernels/cwise_op_mul_1.cc
+tensorflow/core/kernels/cwise_op_mul_2.cc
+tensorflow/core/kernels/cwise_op_neg.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_pow.cc
+tensorflow/core/kernels/cwise_op_reciprocal.cc
+tensorflow/core/kernels/cwise_op_right_shift.cc
+tensorflow/core/kernels/cwise_op_round.cc
+tensorflow/core/kernels/cwise_op_rsqrt.cc
+tensorflow/core/kernels/cwise_op_select.cc
+tensorflow/core/kernels/cwise_op_sigmoid.cc
+tensorflow/core/kernels/cwise_op_sign.cc
+tensorflow/core/kernels/cwise_op_sqrt.cc
+tensorflow/core/kernels/cwise_op_square.cc
+tensorflow/core/kernels/cwise_op_squared_difference.cc
+tensorflow/core/kernels/cwise_op_sub.cc
+tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_ops_common.cc
+tensorflow/core/kernels/data_format_ops.cc
+tensorflow/core/kernels/decode_bmp_op.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/decode_wav_op.cc
+tensorflow/core/kernels/deep_conv2d.cc
+tensorflow/core/kernels/dense_update_functor.cc
+tensorflow/core/kernels/dense_update_ops.cc
+tensorflow/core/kernels/depthtospace_op.cc
 tensorflow/core/kernels/depthwise_conv_op.cc
 tensorflow/core/kernels/dequantize_op.cc
+tensorflow/core/kernels/dynamic_partition_op.cc
+tensorflow/core/kernels/dynamic_stitch_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/kernels/encode_wav_op.cc
+tensorflow/core/kernels/example_parsing_ops.cc
+tensorflow/core/kernels/fake_quant_ops.cc
+tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
+tensorflow/core/kernels/fill_functor.cc
+tensorflow/core/kernels/function_ops.cc
+tensorflow/core/kernels/fused_batch_norm_op.cc
+tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/gather_op.cc
+tensorflow/core/kernels/identity_n_op.cc
+tensorflow/core/kernels/identity_op.cc
+tensorflow/core/kernels/immutable_constant_op.cc
+tensorflow/core/kernels/in_topk_op.cc
+tensorflow/core/kernels/initializable_lookup_table.cc
+tensorflow/core/kernels/inplace_ops.cc
+tensorflow/core/kernels/listdiff_op.cc
+tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
+tensorflow/core/kernels/lrn_op.cc
+tensorflow/core/kernels/matmul_op.cc
+tensorflow/core/kernels/maxpooling_op.cc
 tensorflow/core/kernels/meta_support.cc
+tensorflow/core/kernels/mfcc.cc
+tensorflow/core/kernels/mfcc_dct.cc
+tensorflow/core/kernels/mfcc_mel_filterbank.cc
+tensorflow/core/kernels/mfcc_op.cc
+tensorflow/core/kernels/mirror_pad_op.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
+tensorflow/core/kernels/no_op.cc
+tensorflow/core/kernels/non_max_suppression_op.cc
+tensorflow/core/kernels/one_hot_op.cc
+tensorflow/core/kernels/ops_util.cc
+tensorflow/core/kernels/pack_op.cc
+tensorflow/core/kernels/pad_op.cc
+tensorflow/core/kernels/padding_fifo_queue.cc
+tensorflow/core/kernels/padding_fifo_queue_op.cc
+tensorflow/core/kernels/pooling_ops_common.cc
 tensorflow/core/kernels/population_count_op.cc
 tensorflow/core/kernels/quantization_utils.cc
 tensorflow/core/kernels/quantize_down_and_shrink_range.cc
@@ -262,46 +174,135 @@
 tensorflow/core/kernels/quantized_pooling_ops.cc
 tensorflow/core/kernels/quantized_reshape_op.cc
 tensorflow/core/kernels/quantized_resize_bilinear_op.cc
-tensorflow/core/kernels/requantization_range_op.cc
-tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/queue_base.cc
+tensorflow/core/kernels/queue_op.cc
+tensorflow/core/kernels/queue_ops.cc
+tensorflow/core/kernels/random_op.cc
+tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/reduction_ops_any.cc
+tensorflow/core/kernels/reduction_ops_common.cc
+tensorflow/core/kernels/reduction_ops_max.cc
+tensorflow/core/kernels/reduction_ops_mean.cc
+tensorflow/core/kernels/reduction_ops_min.cc
+tensorflow/core/kernels/reduction_ops_prod.cc
+tensorflow/core/kernels/reduction_ops_sum.cc
+tensorflow/core/kernels/relu_op.cc
 tensorflow/core/kernels/remote_fused_graph_execute_op.cc
 tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
-tensorflow/core/kernels/batch_matmul_op_real.cc
-tensorflow/core/kernels/random_op.cc
-tensorflow/core/ops/training_ops.cc
-tensorflow/core/ops/string_ops.cc
-tensorflow/core/ops/state_ops.cc
-tensorflow/core/ops/sparse_ops.cc
-tensorflow/core/ops/sendrecv_ops.cc
-tensorflow/core/ops/script_ops.cc
-tensorflow/core/ops/remote_fused_graph_ops.cc
-tensorflow/core/ops/random_ops.cc
-tensorflow/core/ops/random_grad.cc
-tensorflow/core/ops/parsing_ops.cc
-tensorflow/core/ops/no_op.cc
-tensorflow/core/ops/nn_ops.cc
-tensorflow/core/ops/nn_grad.cc
-tensorflow/core/ops/manip_ops.cc
-tensorflow/core/ops/math_ops.cc
-tensorflow/core/ops/math_grad.cc
-tensorflow/core/ops/logging_ops.cc
-tensorflow/core/ops/linalg_ops.cc
-tensorflow/core/ops/io_ops.cc
-tensorflow/core/ops/image_ops.cc
-tensorflow/core/ops/functional_ops.cc
-tensorflow/core/ops/functional_grad.cc
-tensorflow/core/ops/function_ops.cc
-tensorflow/core/ops/data_flow_ops.cc
-tensorflow/core/ops/ctc_ops.cc
-tensorflow/core/ops/control_flow_ops.cc
-tensorflow/core/ops/candidate_sampling_ops.cc
-tensorflow/core/ops/boosted_trees_ops.cc
-tensorflow/core/ops/array_ops.cc
-tensorflow/core/ops/array_grad.cc
+tensorflow/core/kernels/requantization_range_op.cc
+tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/reshape_op.cc
+tensorflow/core/kernels/reshape_util.cc
+tensorflow/core/kernels/resize_bilinear_op.cc
+tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+tensorflow/core/kernels/restore_op.cc
+tensorflow/core/kernels/reverse_op.cc
+tensorflow/core/kernels/reverse_sequence_op.cc
+tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/save_op.cc
+tensorflow/core/kernels/save_restore_tensor.cc
+tensorflow/core/kernels/save_restore_v2_ops.cc
+tensorflow/core/kernels/scatter_functor.cc
+tensorflow/core/kernels/scatter_nd_op.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/scatter_op.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/sendrecv_ops.cc
+tensorflow/core/kernels/sequence_ops.cc
+tensorflow/core/kernels/session_ops.cc
+tensorflow/core/kernels/shape_ops.cc
+tensorflow/core/kernels/slice_op.cc
+tensorflow/core/kernels/slice_op_cpu_impl_1.cc
+tensorflow/core/kernels/slice_op_cpu_impl_2.cc
+tensorflow/core/kernels/slice_op_cpu_impl_3.cc
+tensorflow/core/kernels/slice_op_cpu_impl_4.cc
+tensorflow/core/kernels/slice_op_cpu_impl_5.cc
+tensorflow/core/kernels/slice_op_cpu_impl_6.cc
+tensorflow/core/kernels/slice_op_cpu_impl_7.cc
+tensorflow/core/kernels/softmax_op.cc
+tensorflow/core/kernels/softplus_op.cc
+tensorflow/core/kernels/softsign_op.cc
 tensorflow/core/kernels/spacetobatch_functor.cc
 tensorflow/core/kernels/spacetobatch_op.cc
-tensorflow/core/kernels/batchtospace_op.cc
-tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_reshape_op.cc
+tensorflow/core/kernels/sparse_to_dense_op.cc
+tensorflow/core/kernels/spectrogram.cc
+tensorflow/core/kernels/spectrogram_op.cc
+tensorflow/core/kernels/split_lib_cpu.cc
+tensorflow/core/kernels/split_op.cc
+tensorflow/core/kernels/split_v_op.cc
+tensorflow/core/kernels/stack_ops.cc
+tensorflow/core/kernels/strided_slice_op.cc
+tensorflow/core/kernels/strided_slice_op_inst_0.cc
+tensorflow/core/kernels/strided_slice_op_inst_1.cc
+tensorflow/core/kernels/strided_slice_op_inst_2.cc
+tensorflow/core/kernels/strided_slice_op_inst_3.cc
+tensorflow/core/kernels/strided_slice_op_inst_4.cc
+tensorflow/core/kernels/strided_slice_op_inst_5.cc
+tensorflow/core/kernels/strided_slice_op_inst_6.cc
+tensorflow/core/kernels/strided_slice_op_inst_7.cc
+tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/tensor_array.cc
+tensorflow/core/kernels/tensor_array_ops.cc
+tensorflow/core/kernels/tile_functor_cpu.cc
+tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
+tensorflow/core/kernels/topk_op.cc
+tensorflow/core/kernels/training_op_helpers.cc
+tensorflow/core/kernels/training_ops.cc
+tensorflow/core/kernels/transpose_functor_cpu.cc
+tensorflow/core/kernels/transpose_op.cc
+tensorflow/core/kernels/unique_op.cc
+tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/variable_ops.cc
+tensorflow/core/kernels/where_op.cc
+tensorflow/core/kernels/xent_op.cc
+tensorflow/core/kernels/xsmm_conv2d.cc
+tensorflow/core/ops/array_grad.cc
+tensorflow/core/ops/array_ops.cc
 tensorflow/core/ops/audio_ops.cc
-tensorflow/core/kernels/decode_proto_op.cc
-tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/ops/boosted_trees_ops.cc
+tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/control_flow_ops.cc
+tensorflow/core/ops/ctc_ops.cc
+tensorflow/core/ops/data_flow_ops.cc
+tensorflow/core/ops/function_ops.cc
+tensorflow/core/ops/functional_grad.cc
+tensorflow/core/ops/functional_ops.cc
+tensorflow/core/ops/image_ops.cc
+tensorflow/core/ops/io_ops.cc
+tensorflow/core/ops/linalg_ops.cc
+tensorflow/core/ops/logging_ops.cc
+tensorflow/core/ops/manip_ops.cc
+tensorflow/core/ops/math_grad.cc
+tensorflow/core/ops/math_ops.cc
+tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/nn_ops.cc
+tensorflow/core/ops/no_op.cc
+tensorflow/core/ops/parsing_ops.cc
+tensorflow/core/ops/random_grad.cc
+tensorflow/core/ops/random_ops.cc
+tensorflow/core/ops/remote_fused_graph_ops.cc
+tensorflow/core/ops/script_ops.cc
+tensorflow/core/ops/sendrecv_ops.cc
+tensorflow/core/ops/sparse_ops.cc
+tensorflow/core/ops/state_ops.cc
+tensorflow/core/ops/string_ops.cc
+tensorflow/core/ops/training_ops.cc
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index b5431df..e23f499 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,33 +1,34 @@
-tensorflow/core/util/saved_tensor_slice.pb_text.cc
-tensorflow/core/util/memmapped_file_system.pb_text.cc
-tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/example/example.pb_text.cc
+tensorflow/core/example/feature.pb_text.cc
+tensorflow/core/framework/allocation_description.pb_text.cc
+tensorflow/core/framework/api_def.pb_text.cc
+tensorflow/core/framework/attr_value.pb_text.cc
+tensorflow/core/framework/cost_graph.pb_text.cc
+tensorflow/core/framework/device_attributes.pb_text.cc
+tensorflow/core/framework/function.pb_text.cc
+tensorflow/core/framework/graph.pb_text.cc
+tensorflow/core/framework/graph_transfer_info.pb_text.cc
+tensorflow/core/framework/kernel_def.pb_text.cc
+tensorflow/core/framework/log_memory.pb_text.cc
+tensorflow/core/framework/model.pb_text.cc
+tensorflow/core/framework/node_def.pb_text.cc
+tensorflow/core/framework/op_def.pb_text.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
+tensorflow/core/framework/resource_handle.pb_text.cc
+tensorflow/core/framework/step_stats.pb_text.cc
+tensorflow/core/framework/summary.pb_text.cc
+tensorflow/core/framework/tensor.pb_text.cc
+tensorflow/core/framework/tensor_description.pb_text.cc
+tensorflow/core/framework/tensor_shape.pb_text.cc
+tensorflow/core/framework/tensor_slice.pb_text.cc
+tensorflow/core/framework/types.pb_text.cc
+tensorflow/core/framework/versions.pb_text.cc
+tensorflow/core/lib/core/error_codes.pb_text.cc
 tensorflow/core/protobuf/cluster.pb_text.cc
 tensorflow/core/protobuf/config.pb_text.cc
 tensorflow/core/protobuf/debug.pb_text.cc
 tensorflow/core/protobuf/rewriter_config.pb_text.cc
+tensorflow/core/protobuf/saver.pb_text.cc
 tensorflow/core/protobuf/tensor_bundle.pb_text.cc
-tensorflow/core/lib/core/error_codes.pb_text.cc
-tensorflow/core/framework/versions.pb_text.cc
-tensorflow/core/framework/types.pb_text.cc
-tensorflow/core/framework/tensor_slice.pb_text.cc
-tensorflow/core/framework/tensor_shape.pb_text.cc
-tensorflow/core/framework/tensor_description.pb_text.cc
-tensorflow/core/framework/tensor.pb_text.cc
-tensorflow/core/framework/summary.pb_text.cc
-tensorflow/core/framework/step_stats.pb_text.cc
-tensorflow/core/framework/resource_handle.pb_text.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
-tensorflow/core/framework/api_def.pb_text.cc
-tensorflow/core/framework/op_def.pb_text.cc
-tensorflow/core/framework/node_def.pb_text.cc
-tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/kernel_def.pb_text.cc
-tensorflow/core/framework/graph_transfer_info.pb_text.cc
-tensorflow/core/framework/graph.pb_text.cc
-tensorflow/core/framework/function.pb_text.cc
-tensorflow/core/framework/device_attributes.pb_text.cc
-tensorflow/core/framework/cost_graph.pb_text.cc
-tensorflow/core/framework/attr_value.pb_text.cc
-tensorflow/core/framework/allocation_description.pb_text.cc
-tensorflow/core/example/feature.pb_text.cc
-tensorflow/core/example/example.pb_text.cc
+tensorflow/core/util/memmapped_file_system.pb_text.cc
+tensorflow/core/util/saved_tensor_slice.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 1f25469..5eae845 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -2,47 +2,48 @@
 tensorflow/contrib/boosted_trees/proto/quantiles.proto
 tensorflow/contrib/boosted_trees/proto/split_info.proto
 tensorflow/contrib/boosted_trees/proto/tree_config.proto
-tensorflow/core/util/test_log.proto
-tensorflow/core/util/saved_tensor_slice.proto
-tensorflow/core/util/memmapped_file_system.proto
-tensorflow/core/util/event.proto
-tensorflow/core/protobuf/tensorflow_server.proto
-tensorflow/core/protobuf/saver.proto
-tensorflow/core/protobuf/queue_runner.proto
-tensorflow/core/protobuf/named_tensor.proto
-tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/example/example.proto
+tensorflow/core/example/feature.proto
+tensorflow/core/framework/allocation_description.proto
+tensorflow/core/framework/api_def.proto
+tensorflow/core/framework/attr_value.proto
+tensorflow/core/framework/cost_graph.proto
+tensorflow/core/framework/device_attributes.proto
+tensorflow/core/framework/function.proto
+tensorflow/core/framework/graph.proto
+tensorflow/core/framework/graph_transfer_info.proto
+tensorflow/core/framework/kernel_def.proto
+tensorflow/core/framework/log_memory.proto
+tensorflow/core/framework/model.proto
+tensorflow/core/framework/node_def.proto
+tensorflow/core/framework/op_def.proto
+tensorflow/core/framework/reader_base.proto
+tensorflow/core/framework/remote_fused_graph_execute_info.proto
+tensorflow/core/framework/resource_handle.proto
+tensorflow/core/framework/step_stats.proto
+tensorflow/core/framework/summary.proto
+tensorflow/core/framework/tensor.proto
+tensorflow/core/framework/tensor_description.proto
+tensorflow/core/framework/tensor_shape.proto
+tensorflow/core/framework/tensor_slice.proto
+tensorflow/core/framework/types.proto
+tensorflow/core/framework/variable.proto
+tensorflow/core/framework/versions.proto
+tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+tensorflow/core/lib/core/error_codes.proto
 tensorflow/core/protobuf/cluster.proto
 tensorflow/core/protobuf/config.proto
 tensorflow/core/protobuf/debug.proto
 tensorflow/core/protobuf/device_properties.proto
+tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/named_tensor.proto
+tensorflow/core/protobuf/queue_runner.proto
 tensorflow/core/protobuf/rewriter_config.proto
+tensorflow/core/protobuf/saver.proto
 tensorflow/core/protobuf/tensor_bundle.proto
-tensorflow/core/lib/core/error_codes.proto
-tensorflow/core/kernels/boosted_trees/boosted_trees.proto
-tensorflow/core/framework/versions.proto
-tensorflow/core/framework/variable.proto
-tensorflow/core/framework/types.proto
-tensorflow/core/framework/tensor_slice.proto
-tensorflow/core/framework/tensor_shape.proto
-tensorflow/core/framework/tensor_description.proto
-tensorflow/core/framework/tensor.proto
-tensorflow/core/framework/summary.proto
-tensorflow/core/framework/step_stats.proto
-tensorflow/core/framework/resource_handle.proto
-tensorflow/core/framework/remote_fused_graph_execute_info.proto
-tensorflow/core/framework/reader_base.proto
-tensorflow/core/framework/api_def.proto
-tensorflow/core/framework/op_def.proto
-tensorflow/core/framework/node_def.proto
-tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/kernel_def.proto
-tensorflow/core/framework/graph_transfer_info.proto
-tensorflow/core/framework/graph.proto
-tensorflow/core/framework/function.proto
-tensorflow/core/framework/device_attributes.proto
-tensorflow/core/framework/cost_graph.proto
-tensorflow/core/framework/attr_value.proto
-tensorflow/core/framework/allocation_description.proto
-tensorflow/core/example/feature.proto
-tensorflow/core/example/example.proto
-tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/protobuf/tensorflow_server.proto
+tensorflow/core/util/event.proto
+tensorflow/core/util/memmapped_file_system.proto
+tensorflow/core/util/saved_tensor_slice.proto
+tensorflow/core/util/test_log.proto
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index c35e60a..b1c852c 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -31,6 +31,7 @@
 from tensorflow.python.framework import graph_util as _graph_util
 from tensorflow.python.framework import importer as _importer
 from tensorflow.python.framework import ops as _ops
+from tensorflow.python.platform import tf_logging as _logging
 from tensorflow.python.saved_model import constants as _saved_model_constants
 from tensorflow.python.training import saver as _saver_lib
 from tensorflow.python.util import compat as _compat
@@ -476,6 +477,12 @@
     collection.bytes_list.value[:] = [
         s for s in base_collection.bytes_list.value
         if not _is_removed_mentioned(s, removed_op_names)]
+    _logging.info(
+        'In collection %s, nodes excluded are: %s', collection_name,
+        sorted([
+            s for s in base_collection.bytes_list.value
+            if _is_removed_mentioned(s, removed_op_names)
+        ]))
   elif base_collection.HasField('node_list'):
     collection.node_list.value[:] = [
         s for s in base_collection.node_list.value
@@ -745,6 +752,9 @@
   retained_op_names = [_compat.as_str(node.name)
                        for node in meta_graph_def.graph_def.node]
   removed_op_names = set(base_op_names) - set(retained_op_names)
+  _logging.info('Node names in base graph: %s', sorted(base_op_names))
+  _logging.info('Node names retained: %s', sorted(retained_op_names))
+  _logging.info('Node names removed: %s', sorted(removed_op_names))
 
   # Copy saver, excluding any pruned nodes if graph was not frozen.
   # TODO(b/63447631): Revisit this once the problem is addressed. Currently
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
index 7acfc38..5777e64 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -47,7 +47,7 @@
     # code used float32 for accumulation.
     num_updates = 71
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for _ in xrange(num_updates):
         sess.run(update_op)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 024bd54..955b83b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -178,7 +178,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -195,7 +195,7 @@
       self.assertAlmostEqual(1.65, sess.run(mean), 5)
 
   def testUpdateOpsReturnsCurrentValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -216,7 +216,7 @@
       self.assertAlmostEqual(1.65, sess.run(mean), 5)
 
   def test1dWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -243,7 +243,7 @@
       self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
 
   def test1dWeightedValues_placeholders(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
       values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -265,7 +265,7 @@
       self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
 
   def test2dWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -292,7 +292,7 @@
       self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5)
 
   def test2dWeightedValues_placeholders(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
       values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -337,7 +337,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -354,7 +354,7 @@
       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
 
   def testMultiDimensional(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
       _enqueue_vector(
@@ -375,7 +375,7 @@
       self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
 
   def testUpdateOpsReturnsCurrentValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -396,7 +396,7 @@
       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
 
   def testWeighted1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -423,7 +423,7 @@
       self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
 
   def testWeighted2d_1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -450,7 +450,7 @@
       self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
 
   def testWeighted2d_2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -526,7 +526,7 @@
         (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2)
     accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -539,7 +539,7 @@
         self.assertEqual(initial_accuracy, accuracy.eval())
 
   def testMultipleUpdates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -569,7 +569,7 @@
   def testEffectivelyEquivalentSizes(self):
     predictions = array_ops.ones((40, 1))
     labels = array_ops.ones((40,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
 
       sess.run(variables.local_variables_initializer())
@@ -583,7 +583,7 @@
     weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
                                     1)  # shape 3, 1
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
                                                        weights)
 
@@ -604,7 +604,7 @@
         dtype=dtypes_lib.int32, name='weights')
     feed_dict = {weights_placeholder: weights}
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
                                                        weights_placeholder)
 
@@ -616,7 +616,7 @@
       self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
 
   def testMultipleUpdatesWithWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -681,7 +681,7 @@
           tp, tp_update_op = metrics.streaming_true_positives(
               predictions, labels)
 
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             sess.run(variables.local_variables_initializer())
             self.assertEqual(0, tp.eval())
             self.assertEqual(1, tp_update_op.eval())
@@ -698,7 +698,7 @@
       tp, tp_update_op = metrics.streaming_true_positives(
           predictions, labels, weights=37.0)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertEqual(0, tp.eval())
         self.assertEqual(37.0, tp_update_op.eval())
@@ -732,7 +732,7 @@
           fn, fn_update_op = metrics.streaming_false_negatives(
               predictions, labels)
 
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             sess.run(variables.local_variables_initializer())
             self.assertEqual(0, fn.eval())
             self.assertEqual(2, fn_update_op.eval())
@@ -749,7 +749,7 @@
       fn, fn_update_op = metrics.streaming_false_negatives(
           predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertEqual(0, fn.eval())
         self.assertEqual(8.0, fn_update_op.eval())
@@ -783,7 +783,7 @@
           fp, fp_update_op = metrics.streaming_false_positives(
               predictions, labels)
 
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             sess.run(variables.local_variables_initializer())
             self.assertEqual(0, fp.eval())
             self.assertEqual(4, fp_update_op.eval())
@@ -803,7 +803,7 @@
           weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
                                                                    29.0, 31.0)))
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertEqual(0, fp.eval())
         self.assertEqual(42.0, fp_update_op.eval())
@@ -837,7 +837,7 @@
           tn, tn_update_op = metrics.streaming_true_negatives(
               predictions, labels)
 
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             sess.run(variables.local_variables_initializer())
             self.assertEqual(0, tn.eval())
             self.assertEqual(5, tn_update_op.eval())
@@ -854,7 +854,7 @@
       tn, tn_update_op = metrics.streaming_true_negatives(
           predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertEqual(0, tn.eval())
         self.assertEqual(15.0, tn_update_op.eval())
@@ -879,7 +879,7 @@
     tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
         predictions, labels, thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), tp.eval())
       self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -892,7 +892,7 @@
     tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
         predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
       self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
@@ -921,7 +921,7 @@
     fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
         predictions, labels, thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), fn.eval())
       self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -937,7 +937,7 @@
         weights=((3.0,), (5.0,), (7.0,)),
         thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
       self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -962,7 +962,7 @@
     fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
         predictions, labels, thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), fp.eval())
       self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -979,7 +979,7 @@
                                                                  29.0, 31.0)),
         thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
       self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -1004,7 +1004,7 @@
     tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
         predictions, labels, thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), tn.eval())
       self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -1020,7 +1020,7 @@
         weights=((0.0, 2.0, 3.0, 5.0),),
         thresholds=(0.15, 0.5, 0.85))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
       self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -1062,7 +1062,7 @@
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     precision, update_op = metrics.streaming_precision(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1081,7 +1081,7 @@
     labels = constant_op.constant(inputs)
     precision, update_op = metrics.streaming_precision(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1, sess.run(update_op))
       self.assertAlmostEqual(1, precision.eval())
@@ -1091,7 +1091,7 @@
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     precision, update_op = metrics.streaming_precision(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.5, update_op.eval())
       self.assertAlmostEqual(0.5, precision.eval())
@@ -1102,7 +1102,7 @@
     precision, update_op = metrics.streaming_precision(
         predictions, labels, weights=constant_op.constant([[2], [5]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 2.0 + 5.0
       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1120,7 +1120,7 @@
     precision, update_op = metrics.streaming_precision(
         predictions, labels, weights=constant_op.constant([[2], [5]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 2.0 + 5.0
       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1138,7 +1138,7 @@
         labels,
         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 3.0 + 4.0
       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1158,7 +1158,7 @@
         labels,
         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 3.0 + 4.0
       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1175,7 +1175,7 @@
     labels = constant_op.constant(1 - inputs)
     precision, update_op = metrics.streaming_precision(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertAlmostEqual(0, precision.eval())
@@ -1185,7 +1185,7 @@
     labels = constant_op.constant([0, 0, 0, 0])
     precision, update_op = metrics.streaming_precision(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0.0, precision.eval())
@@ -1227,7 +1227,7 @@
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     recall, update_op = metrics.streaming_recall(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1246,7 +1246,7 @@
     labels = constant_op.constant(np_inputs)
     recall, update_op = metrics.streaming_recall(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(1, recall.eval())
@@ -1256,7 +1256,7 @@
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     recall, update_op = metrics.streaming_recall(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.5, update_op.eval())
       self.assertAlmostEqual(0.5, recall.eval())
@@ -1268,7 +1268,7 @@
     recall, update_op = metrics.streaming_recall(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_tp = 2.0 + 5.0
       weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1283,7 +1283,7 @@
     recall, update_op = metrics.streaming_recall(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_tp = 3.0 + 1.0
       weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1298,7 +1298,7 @@
     labels = constant_op.constant(1 - np_inputs)
     recall, update_op = metrics.streaming_recall(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, recall.eval())
@@ -1308,7 +1308,7 @@
     labels = array_ops.zeros((1, 4))
     recall, update_op = metrics.streaming_recall(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, recall.eval())
@@ -1350,7 +1350,7 @@
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1369,7 +1369,7 @@
     labels = constant_op.constant(np_inputs)
     fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, fpr.eval())
@@ -1379,7 +1379,7 @@
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.5, update_op.eval())
       self.assertAlmostEqual(0.5, fpr.eval())
@@ -1391,7 +1391,7 @@
     fpr, update_op = metrics.streaming_false_positive_rate(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_fp = 2.0 + 5.0
       weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1406,7 +1406,7 @@
     fpr, update_op = metrics.streaming_false_positive_rate(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_fp = 1.0 + 3.0
       weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
@@ -1421,7 +1421,7 @@
     labels = constant_op.constant(1 - np_inputs)
     fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(1, fpr.eval())
@@ -1431,7 +1431,7 @@
     labels = array_ops.ones((1, 4))
     fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, fpr.eval())
@@ -1473,7 +1473,7 @@
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1492,7 +1492,7 @@
     labels = constant_op.constant(np_inputs)
     fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, fnr.eval())
@@ -1502,7 +1502,7 @@
     labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
     fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.5, update_op.eval())
       self.assertAlmostEqual(0.5, fnr.eval())
@@ -1514,7 +1514,7 @@
     fnr, update_op = metrics.streaming_false_negative_rate(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_fn = 2.0 + 5.0
       weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1529,7 +1529,7 @@
     fnr, update_op = metrics.streaming_false_negative_rate(
         predictions, labels, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_fn = 2.0 + 4.0
       weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
@@ -1544,7 +1544,7 @@
     labels = constant_op.constant(1 - np_inputs)
     fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(1, fnr.eval())
@@ -1554,7 +1554,7 @@
     labels = array_ops.zeros((1, 4))
     fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, fnr.eval())
@@ -1599,7 +1599,7 @@
     points, update_op = metric_ops.streaming_curve_points(
         labels, predictions=predictions, curve=curve)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       sess.run(update_op)
@@ -1615,7 +1615,7 @@
     self._testValueTensorIsIdempotent(curve='PR')
 
   def _testCase(self, labels, predictions, curve, expected_points):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions_tensor = constant_op.constant(
           predictions, dtype=dtypes_lib.float32)
       labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
@@ -1717,7 +1717,7 @@
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     auc, update_op = metrics.streaming_auc(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1730,7 +1730,7 @@
         self.assertAlmostEqual(initial_auc, auc.eval(), 5)
 
   def testPredictionsOutOfRange(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1744,7 +1744,7 @@
   def allCorrectAsExpected(self, curve):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve)
@@ -1755,7 +1755,7 @@
       self.assertEqual(1, auc.eval())
 
   def testSomeCorrect(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1767,7 +1767,7 @@
       self.assertAlmostEqual(0.5, auc.eval())
 
   def testWeighted1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1781,7 +1781,7 @@
       self.assertAlmostEqual(0.5, auc.eval(), 5)
 
   def testWeighted2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1795,7 +1795,7 @@
       self.assertAlmostEqual(0.7, auc.eval(), 5)
 
   def testAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1807,7 +1807,7 @@
       self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
 
   def testAnotherAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
           shape=(1, 7),
@@ -1821,7 +1821,7 @@
       self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
 
   def testThirdAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
           shape=(1, 7),
@@ -1837,7 +1837,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1848,7 +1848,7 @@
       self.assertAlmostEqual(0, auc.eval())
 
   def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1859,7 +1859,7 @@
       self.assertAlmostEqual(1, auc.eval(), 6)
 
   def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
       labels = array_ops.ones([4])
       auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
@@ -1893,7 +1893,7 @@
                     np.random.exponential(scale=1.0, size=num_samples)):
       expected_auc = _np_auc(predictions, labels, weights)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         enqueue_ops = [[] for i in range(num_batches)]
         tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
         tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1966,7 +1966,7 @@
     labels = random_ops.random_uniform(
         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
     auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       # Run several updates.
       for _ in xrange(10):
@@ -1977,7 +1977,7 @@
         self.assertAlmostEqual(initial_auc, auc.eval(), 5)
 
   def testAllLabelsOnes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1., 1., 1.])
       labels = constant_op.constant([1, 1, 1])
       auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1986,7 +1986,7 @@
       self.assertEqual(0, auc.eval())
 
   def testAllLabelsZeros(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1., 1., 1.])
       labels = constant_op.constant([0, 0, 0])
       auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1995,7 +1995,7 @@
       self.assertEqual(0, auc.eval())
 
   def testNonZeroOnePredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
       labels = constant_op.constant([1, 0, 1, 0])
@@ -2006,7 +2006,7 @@
 
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs)
       labels = constant_op.constant(inputs)
       auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2015,7 +2015,7 @@
       self.assertEqual(1, auc.eval())
 
   def testSomeCorrect(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0, 1, 0])
       labels = constant_op.constant([0, 1, 1, 0])
       auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2025,7 +2025,7 @@
 
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2034,7 +2034,7 @@
       self.assertAlmostEqual(0, auc.eval())
 
   def testExceptionOnIncompatibleShapes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.ones([5])
       labels = array_ops.zeros([6])
       with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
@@ -2043,7 +2043,7 @@
         sess.run(update_op)
 
   def testExceptionOnGreaterThanOneLabel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
       labels = constant_op.constant([2, 1, 0])
       _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2054,7 +2054,7 @@
         sess.run(update_op)
 
   def testExceptionOnNegativeLabel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
       labels = constant_op.constant([1, 0, -1])
       _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2078,7 +2078,7 @@
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         dtype=dtypes_lib.float32)
     auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for _ in xrange(num_batches):
         new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2093,7 +2093,7 @@
         self.assertAlmostEqual(expected_auc, auc.eval())
 
   def testAUCPRReverseIncreasingPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 1])
@@ -2104,7 +2104,7 @@
       self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
 
   def testAUCPRJumbledPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
@@ -2115,7 +2115,7 @@
       self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
 
   def testAUCPRPredictionsLessThanHalf(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
           shape=(1, 7),
@@ -2148,7 +2148,7 @@
     auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
                                                    tf_predictions,
                                                    weights=tf_weights)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for _ in xrange(num_batches):
         new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2196,7 +2196,7 @@
       expected_result: The expected result (dict) that maps to tensors.
       weights: Optional weights tensor.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions_tensor = constant_op.constant(
           predictions, dtype=dtypes_lib.float32)
       labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
@@ -2320,7 +2320,7 @@
         dtype=dtypes_lib.float32)
     auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
                                                            tf_predictions)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for _ in xrange(num_batches):
         new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2335,7 +2335,7 @@
         self.assertAllClose(expected_auc, auc.auc.eval())
 
   def testExceptionOnFloatLabels(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
       labels = constant_op.constant([0.7, 0, 1, 0, 1])
       _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2343,7 +2343,7 @@
       self.assertRaises(TypeError, sess.run(update_op))
 
   def testExceptionOnGreaterThanOneLabel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
       labels = constant_op.constant([2, 1, 0, 1, 0])
       _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2354,7 +2354,7 @@
         sess.run(update_op)
 
   def testExceptionOnNegativeLabel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
       labels = constant_op.constant([1, 0, -1, 1, 0])
       _, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2415,7 +2415,7 @@
     result, update_op = metric_ops.precision_recall_at_equal_thresholds(
         labels=labels, predictions=predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Run several updates.
       sess.run(variables.local_variables_initializer())
       for _ in range(3):
@@ -2448,7 +2448,7 @@
         default from assertAllClose.
       weights: Optional weights tensor.
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions_tensor = constant_op.constant(predictions, dtype=dtype)
       labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
       weights_tensor = None
@@ -2621,7 +2621,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, sensitivity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2641,7 +2641,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, sensitivity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, specificity.eval())
@@ -2656,7 +2656,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, sensitivity=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1.0, sess.run(update_op))
       self.assertAlmostEqual(1.0, specificity.eval())
@@ -2671,7 +2671,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, sensitivity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2689,7 +2689,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, weights=weights, sensitivity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2707,7 +2707,7 @@
     specificity, update_op = metrics.streaming_specificity_at_sensitivity(
         predictions, labels, weights=weights, sensitivity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -2757,7 +2757,7 @@
     sensitivity, update_op = metrics.streaming_sensitivity_at_specificity(
         predictions, labels, specificity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2777,7 +2777,7 @@
     specificity, update_op = metrics.streaming_sensitivity_at_specificity(
         predictions, labels, specificity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, specificity.eval())
@@ -2792,7 +2792,7 @@
     specificity, update_op = metrics.streaming_sensitivity_at_specificity(
         predictions, labels, specificity=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.8, sess.run(update_op))
       self.assertAlmostEqual(0.8, specificity.eval())
@@ -2807,7 +2807,7 @@
     specificity, update_op = metrics.streaming_sensitivity_at_specificity(
         predictions, labels, specificity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.6, sess.run(update_op))
       self.assertAlmostEqual(0.6, specificity.eval())
@@ -2824,7 +2824,7 @@
     specificity, update_op = metrics.streaming_sensitivity_at_specificity(
         predictions, labels, weights=weights, specificity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.675, sess.run(update_op))
       self.assertAlmostEqual(0.675, specificity.eval())
@@ -2887,7 +2887,7 @@
     rec, rec_op = metrics.streaming_recall_at_thresholds(
         predictions, labels, thresholds)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2905,7 +2905,7 @@
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       thresholds = [0.5]
@@ -2921,7 +2921,7 @@
       self.assertEqual(1, rec.eval())
 
   def testSomeCorrect(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -2940,7 +2940,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       thresholds = [0.5]
@@ -2956,7 +2956,7 @@
       self.assertAlmostEqual(0, rec.eval())
 
   def testWeights1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -2982,7 +2982,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
 
   def testWeights2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3008,7 +3008,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
 
   def testExtremeThresholds(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3032,7 +3032,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval())
 
   def testZeroLabelsPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       thresholds = [0.5]
@@ -3082,7 +3082,7 @@
     labels = labels.astype(np.float32)
     predictions = predictions.astype(np.float32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Reshape the data so its easy to queue up:
       predictions_batches = predictions.reshape((batch_size, num_batches))
       labels_batches = labels.reshape((batch_size, num_batches))
@@ -3162,7 +3162,7 @@
     fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
         predictions, labels, thresholds)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3177,7 +3177,7 @@
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       thresholds = [0.5]
@@ -3190,7 +3190,7 @@
       self.assertEqual(0, fpr.eval())
 
   def testSomeCorrect(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3206,7 +3206,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       thresholds = [0.5]
@@ -3219,7 +3219,7 @@
       self.assertAlmostEqual(1, fpr.eval())
 
   def testWeights1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3239,7 +3239,7 @@
       self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
 
   def testWeights2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3259,7 +3259,7 @@
       self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
 
   def testExtremeThresholds(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3277,7 +3277,7 @@
       self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
 
   def testZeroLabelsPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       thresholds = [0.5]
@@ -3317,7 +3317,7 @@
     labels = labels.astype(np.float32)
     predictions = predictions.astype(np.float32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Reshape the data so its easy to queue up:
       predictions_batches = predictions.reshape((batch_size, num_batches))
       labels_batches = labels.reshape((batch_size, num_batches))
@@ -3393,7 +3393,7 @@
     recall, update_op = metrics.recall_at_precision(
         labels, predictions, precision=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3413,7 +3413,7 @@
     recall, update_op = metrics.recall_at_precision(
         labels, predictions, precision=1.0)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, recall.eval())
@@ -3428,7 +3428,7 @@
     recall, update_op = metrics.recall_at_precision(
         labels, predictions, precision=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.8, sess.run(update_op))
       self.assertAlmostEqual(0.8, recall.eval())
@@ -3443,7 +3443,7 @@
     recall, update_op = metrics.recall_at_precision(
         labels, predictions, precision=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       target_recall = 2.0 / 3.0
       self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3461,7 +3461,7 @@
     recall, update_op = metrics.recall_at_precision(
         labels, predictions, weights=weights, precision=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       target_recall = 2.0 / 3.0
       self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3486,7 +3486,7 @@
         precision=target_precision,
         strict_mode=strict_mode)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(expected_recall, sess.run(update_op))
       self.assertAlmostEqual(expected_recall, recall.eval())
@@ -3565,7 +3565,7 @@
     precision, update_op = metrics.precision_at_recall(
         labels, predictions, target_recall=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3585,7 +3585,7 @@
     precision, update_op = metrics.precision_at_recall(
         labels, predictions, target_recall=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, precision.eval())
@@ -3599,7 +3599,7 @@
     precision, update_op = metrics.precision_at_recall(
         labels, predictions, target_recall=0.2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(sess.run(label_prior), sess.run(update_op))
       self.assertEqual(sess.run(label_prior), precision.eval())
@@ -3614,7 +3614,7 @@
     precision, update_op = metrics.precision_at_recall(
         labels, predictions, target_recall=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.8, sess.run(update_op))
       self.assertAlmostEqual(0.8, precision.eval())
@@ -3629,7 +3629,7 @@
     precision, update_op = metrics.precision_at_recall(
         labels, predictions, target_recall=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(2.0/3, sess.run(update_op))
       self.assertAlmostEqual(2.0/3, precision.eval())
@@ -3648,7 +3648,7 @@
       precision, update_op = metrics.precision_at_recall(
           labels, predictions, target_recall=0.8, weights=weights)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertAlmostEqual(34.0/43, sess.run(update_op))
         self.assertAlmostEqual(34.0/43, precision.eval())
@@ -3697,7 +3697,7 @@
     fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
         predictions, labels, thresholds)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3712,7 +3712,7 @@
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       thresholds = [0.5]
@@ -3725,7 +3725,7 @@
       self.assertEqual(0, fnr.eval())
 
   def testSomeCorrect(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3741,7 +3741,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       thresholds = [0.5]
@@ -3754,7 +3754,7 @@
       self.assertAlmostEqual(1, fnr.eval())
 
   def testWeights1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3774,7 +3774,7 @@
       self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
 
   def testWeights2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3794,7 +3794,7 @@
       self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
 
   def testExtremeThresholds(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3812,7 +3812,7 @@
       self.assertAlmostEqual(1.0, fnr_high.eval())
 
   def testZeroLabelsPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       thresholds = [0.5]
@@ -3852,7 +3852,7 @@
     labels = labels.astype(np.float32)
     predictions = predictions.astype(np.float32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Reshape the data so its easy to queue up:
       predictions_batches = predictions.reshape((batch_size, num_batches))
       labels_batches = labels.reshape((batch_size, num_batches))
@@ -3940,7 +3940,7 @@
     sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
         predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0.25, sess.run(update_op))
       self.assertEqual(0.25, recall.eval())
@@ -3958,7 +3958,7 @@
     sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
         predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0.5, sess.run(update_op))
       self.assertEqual(0.5, recall.eval())
@@ -3976,7 +3976,7 @@
     sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
         predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1.0, sess.run(update_op))
       self.assertEqual(1.0, recall.eval())
@@ -4000,7 +4000,7 @@
         k=2,
         weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1.0, sess.run(update_op))
       self.assertEqual(1.0, recall.eval())
@@ -4122,7 +4122,7 @@
         self.assertAlmostEqual(expected, metric.eval())
 
   def test_top_k_rank_invalid(self):
-    with self.test_session():
+    with self.cached_session():
       # top_k_predictions has rank < 2.
       top_k_predictions = [9, 4, 6, 2, 0]
       sp_labels = sparse_tensor.SparseTensorValue(
@@ -4669,7 +4669,7 @@
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
     labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
     expected_precision = 0.5
-    with self.test_session():
+    with self.cached_session():
       _, precision = metrics.streaming_sparse_precision_at_k(
           predictions=constant_op.constant(predictions, dtypes_lib.float32),
           labels=_binary_2d_label_to_sparse_value(labels),
@@ -5374,7 +5374,7 @@
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
     labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
     expected_recall = 0.5
-    with self.test_session():
+    with self.cached_session():
       _, recall = metrics.streaming_sparse_recall_at_k(
           predictions=constant_op.constant(predictions, dtypes_lib.float32),
           labels=_binary_2d_label_to_sparse_value(labels),
@@ -5418,7 +5418,7 @@
     error, update_op = metrics.streaming_mean_absolute_error(
         predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -5440,7 +5440,7 @@
     error, update_op = metrics.streaming_mean_absolute_error(
         predictions, labels, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(3, sess.run(update_op))
       self.assertEqual(3, error.eval())
@@ -5484,7 +5484,7 @@
     error, update_op = metrics.streaming_mean_relative_error(
         predictions, labels, normalizer)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -5509,7 +5509,7 @@
     error, update_op = metrics.streaming_mean_relative_error(
         predictions, labels, normalizer=labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(expected_error, sess.run(update_op))
       self.assertEqual(expected_error, error.eval())
@@ -5525,7 +5525,7 @@
     error, update_op = metrics.streaming_mean_relative_error(
         predictions, labels, normalizer=array_ops.zeros_like(labels))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0.0, sess.run(update_op))
       self.assertEqual(0.0, error.eval())
@@ -5563,7 +5563,7 @@
     labels = random_ops.random_normal((10, 3), seed=2)
     error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -5581,7 +5581,7 @@
 
     error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -5594,7 +5594,7 @@
 
     error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(6, sess.run(update_op))
       self.assertEqual(6, error.eval())
@@ -5609,13 +5609,13 @@
     error, update_op = metrics.streaming_mean_squared_error(
         predictions, labels, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(13, sess.run(update_op))
       self.assertEqual(13, error.eval())
 
   def testMultipleBatchesOfSizeOne(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5640,7 +5640,7 @@
       self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
 
   def testMetricsComputedConcurrently(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates one set of predictions.
       preds_queue0 = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5683,7 +5683,7 @@
       self.assertAlmostEqual(79.0 / 6, mse1, 5)
 
   def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5745,7 +5745,7 @@
     error, update_op = metrics.streaming_root_mean_squared_error(
         predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -5758,7 +5758,7 @@
         self.assertEqual(initial_error, error.eval())
 
   def testSingleUpdateZeroError(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           0.0, shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -5772,7 +5772,7 @@
       self.assertEqual(0, rmse.eval())
 
   def testSingleUpdateWithError(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -5786,7 +5786,7 @@
       self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
 
   def testSingleUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -5842,7 +5842,7 @@
     predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
     cov, update_op = metrics.streaming_covariance(predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -5855,7 +5855,7 @@
         self.assertEqual(initial_cov, cov.eval())
 
   def testSingleUpdateIdentical(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = math_ops.to_float(math_ops.range(10))
       labels = math_ops.to_float(math_ops.range(10))
 
@@ -5867,7 +5867,7 @@
       self.assertAlmostEqual(expected_cov, cov.eval(), 5)
 
   def testSingleUpdateNonIdentical(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -5881,7 +5881,7 @@
       self.assertAlmostEqual(expected_cov, cov.eval())
 
   def testSingleUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -5899,7 +5899,7 @@
       self.assertAlmostEqual(expected_cov, cov.eval())
 
   def testMultiUpdateWithErrorNoWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np.random.seed(123)
       n = 100
       predictions = np.random.randn(n)
@@ -5933,7 +5933,7 @@
         prev_expected_cov = expected_cov
 
   def testMultiUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np.random.seed(123)
       n = 100
       predictions = np.random.randn(n)
@@ -6023,7 +6023,7 @@
     pearson_r, update_op = metrics.streaming_pearson_correlation(
         predictions, labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -6036,7 +6036,7 @@
         self.assertEqual(initial_r, pearson_r.eval())
 
   def testSingleUpdateIdentical(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = math_ops.to_float(math_ops.range(10))
       labels = math_ops.to_float(math_ops.range(10))
 
@@ -6049,7 +6049,7 @@
       self.assertAlmostEqual(expected_r, pearson_r.eval(), 5)
 
   def testSingleUpdateNonIdentical(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -6064,7 +6064,7 @@
       self.assertAlmostEqual(expected_r, pearson_r.eval())
 
   def testSingleUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = np.array([2, 4, 6, 8])
       labels = np.array([1, 3, 2, 7])
       weights = np.array([0, 1, 3, 1])
@@ -6085,7 +6085,7 @@
       self.assertAlmostEqual(expected_r, pearson_r.eval())
 
   def testMultiUpdateWithErrorNoWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np.random.seed(123)
       n = 100
       predictions = np.random.randn(n)
@@ -6120,7 +6120,7 @@
         prev_expected_r = expected_r
 
   def testMultiUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np.random.seed(123)
       n = 100
       predictions = np.random.randn(n)
@@ -6162,7 +6162,7 @@
         prev_expected_r = expected_r
 
   def testMultiUpdateWithErrorAndSingletonBatches(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np.random.seed(123)
       n = 100
       predictions = np.random.randn(n)
@@ -6243,7 +6243,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=1)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -6266,7 +6266,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -6283,7 +6283,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1, sess.run(update_op), 5)
       self.assertAlmostEqual(1, error.eval(), 5)
@@ -6305,7 +6305,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1.0, sess.run(update_op), 5)
       self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -6324,7 +6324,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=2, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -6343,7 +6343,7 @@
     error, update_op = metrics.streaming_mean_cosine_distance(
         predictions, labels, dim=2, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1.5, update_op.eval())
       self.assertEqual(1.5, error.eval())
@@ -6378,7 +6378,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testOneUpdate(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
 
@@ -6398,7 +6398,7 @@
       self.assertAlmostEqual(0.0, pcnt2, 5)
 
   def testSomePresentOneUpdate(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
       weights = constant_op.constant(
@@ -6475,7 +6475,7 @@
     miou, update_op = metrics.streaming_mean_iou(
         predictions, labels, num_classes=num_classes)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -6489,7 +6489,7 @@
 
   def testMultipleUpdates(self):
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6521,7 +6521,7 @@
 
   def testMultipleUpdatesWithWeights(self):
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6569,7 +6569,7 @@
     # one class, and thus there is one row and one column with
     # zero entries in the confusion matrix.
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       # There is no prediction for class 2.
       preds_queue = data_flow_ops.FIFOQueue(
@@ -6611,7 +6611,7 @@
         constant_op.constant(1, shape=[7])
     ], 0)
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6624,7 +6624,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.zeros([40])
     num_classes = 1
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6635,7 +6635,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.ones([40])
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6657,7 +6657,7 @@
         constant_op.constant(1, shape=[8]),
         constant_op.constant(0, shape=[1])
     ], 0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(
           predictions, labels, num_classes, weights=weights)
       sess.run(variables.local_variables_initializer())
@@ -6672,7 +6672,7 @@
         [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
                                                     [1, 1, 2, 0, 0, 0]]])
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6684,7 +6684,7 @@
     labels = constant_op.constant([0])
     predictions = constant_op.constant([0])
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6698,7 +6698,7 @@
         [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
                                                     [1, 1, 1, 0, 0, 0]]])
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.streaming_mean_iou(predictions, labels,
                                                    num_classes)
       sess.run(variables.local_variables_initializer())
@@ -6733,7 +6733,7 @@
 
   def testNextArraySize(self):
     next_array_size = metric_ops._next_array_size  # pylint: disable=protected-access
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
       self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
       self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
@@ -6741,7 +6741,7 @@
       self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
 
   def testStreamingConcat(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = array_ops.placeholder(dtypes_lib.int32, [None])
       concatenated, update_op = metrics.streaming_concat(values)
       sess.run(variables.local_variables_initializer())
@@ -6758,7 +6758,7 @@
       self.assertAllEqual(np.arange(10), concatenated.eval())
 
   def testStreamingConcatStringValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = array_ops.placeholder(dtypes_lib.string, [None])
       concatenated, update_op = metrics.streaming_concat(values)
       sess.run(variables.local_variables_initializer())
@@ -6777,7 +6777,7 @@
           concatenated.eval())
 
   def testStreamingConcatMaxSize(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = math_ops.range(3)
       concatenated, update_op = metrics.streaming_concat(values, max_size=5)
       sess.run(variables.local_variables_initializer())
@@ -6794,7 +6794,7 @@
       self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
 
   def testStreamingConcat2D(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = array_ops.reshape(math_ops.range(3), (3, 1))
       concatenated, update_op = metrics.streaming_concat(values, axis=-1)
       sess.run(variables.local_variables_initializer())
@@ -6817,7 +6817,7 @@
           array_ops.placeholder(dtypes_lib.float32, [None, None]))
 
   def testStreamingConcatReset(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = array_ops.placeholder(dtypes_lib.int32, [None])
       concatenated, update_op = metrics.streaming_concat(values)
       sess.run(variables.local_variables_initializer())
@@ -6845,7 +6845,7 @@
         metrics.streaming_mean(values))
     self.assertEqual(len(value_tensors), 1)
     self.assertEqual(len(update_ops), 1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, update_ops[0].eval())
       self.assertEqual(1, value_tensors[0].eval())
@@ -6858,7 +6858,7 @@
         metrics.streaming_mean_squared_error(predictions, labels))
     self.assertEqual(len(value_tensors), 2)
     self.assertEqual(len(update_ops), 2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(2, update_ops[0].eval())
       self.assertEqual(4, update_ops[1].eval())
@@ -6879,7 +6879,7 @@
     self.assertEqual(2, len(names_to_values))
     self.assertEqual(2, len(names_to_updates))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(2, names_to_updates['m1'].eval())
       self.assertEqual(4, names_to_updates['m2'].eval())
@@ -6914,7 +6914,7 @@
     self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -6931,7 +6931,7 @@
       self.assertAlmostEqual(8.0, sess.run(result), 5)
 
   def testUpdateOpsReturnsCurrentValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -6952,7 +6952,7 @@
       self.assertAlmostEqual(8.0, sess.run(result), 5)
 
   def test1dWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6979,7 +6979,7 @@
       self.assertAlmostEqual(3.4, result.eval(), 5)
 
   def test1dWeightedValues_placeholders(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
       values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7001,7 +7001,7 @@
       self.assertAlmostEqual(3.4, result.eval(), 5)
 
   def test2dWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -7028,7 +7028,7 @@
       self.assertAlmostEqual(4.1, result.eval(), 5)
 
   def test2dWeightedValues_placeholders(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
       values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7101,7 +7101,7 @@
         (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
     kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -7135,7 +7135,7 @@
     for dtype in dtypes:
       for shape in shapes:
         for weight in weights:
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             predictions_tensor = constant_op.constant(
                 np.reshape(predictions, shape), dtype=dtype)
             labels_tensor = constant_op.constant(
@@ -7156,7 +7156,7 @@
     # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
     expect = 1.0
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7175,7 +7175,7 @@
     # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
     expect = -0.333333333333
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
       labels = constant_op.constant(labels)
       kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7193,7 +7193,7 @@
     #                          labels, predictions, sample_weight=weights)
     expect = 0.453466583385
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
       labels = constant_op.constant(labels)
       kappa, update_op = metrics.cohen_kappa(
@@ -7218,7 +7218,7 @@
     weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
     kappa, update_op = metrics.cohen_kappa(
         labels_t, predictions_t, num_classes, weights=weights_t)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       for idx in range(0, num_samples, batch_size):
@@ -7256,7 +7256,7 @@
   def testConditionalPackingOptimization(self):
     placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
     values, update_op = metric_ops.streaming_concat(placeholder)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       for feed in range(10):
         sess.run(update_op, feed_dict={placeholder: [feed]})
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
index e85ae7b..586c6c7 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
@@ -37,7 +37,7 @@
     expected_num_masks = 1
     expected_num_rows = 2 * self.dim
     expected_num_cols = 4 * self.dim
-    with self.test_session():
+    with self.cached_session():
       inputs = variables.Variable(
           random_ops.random_normal([self.batch_size, self.dim]))
       c = variables.Variable(
@@ -61,7 +61,7 @@
     expected_num_masks = 1
     expected_num_rows = 2 * self.dim
     expected_num_cols = 4 * self.dim
-    with self.test_session():
+    with self.cached_session():
       inputs = variables.Variable(
           random_ops.random_normal([self.batch_size, self.dim]))
       c = variables.Variable(
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 62996d1..9a9d480 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -31,9 +31,11 @@
         "kernels/nccl_manager.h",
         "kernels/nccl_ops.cc",
     ]),
-    deps = if_cuda([
+    deps = [] + if_cuda([
         "@local_config_nccl//:nccl",
         "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:protos_all_proto_text",
     ]),
 )
 
@@ -57,32 +59,31 @@
         "notap",
     ],
     deps =
-        [
+        if_cuda([
+            "@local_config_nccl//:nccl",
             "//tensorflow/core:cuda",
             "//tensorflow/core:test",
             "//tensorflow/core:test_main",
             "//tensorflow/core:testlib",
-            "@local_config_nccl//:nccl",
-        ],
+        ]),
 )
 
 tf_kernel_library(
     name = "nccl_kernels",
-    srcs = [
+    srcs = if_cuda([
         "kernels/nccl_manager.cc",
         "kernels/nccl_manager.h",
         "kernels/nccl_ops.cc",
         "kernels/nccl_rewrite.cc",
-    ],
-    deps = [
+    ]),
+    deps = if_cuda([
+        "@local_config_nccl//:nccl",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
         "//tensorflow/core:lib",
-        "//tensorflow/core:proto_text",
         "//tensorflow/core:stream_executor",
-        "@local_config_nccl//:nccl",
-    ],
+    ]),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index 4676e93..06ff86e 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/node_builder.h"
 
 namespace tensorflow {
diff --git a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
index cb69c72..d0955cb 100644
--- a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
+++ b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
@@ -31,7 +31,7 @@
   # tests in hyperplane_lsh_probes_test.cc already cover most of the LSH
   # functionality.
   def simple_batch_test(self):
-    with self.test_session():
+    with self.cached_session():
       hyperplanes = np.eye(4)
       points = np.array([[1.2, 0.5, -0.9, -1.0], [2.0, -3.0, 1.0, -1.5]])
       product = np.dot(points, hyperplanes)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 93e5899..2e4d61d 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -159,8 +159,10 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:variables",
         "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
 
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index bbafd59..6c203e5 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -128,12 +128,14 @@
               = list(global_center_variable)[i]
       return local_var
     else:
-      return getter(
-          name,
-          trainable=trainable,
-          collections=collections,
-          *args,
-          **kwargs)
+      kwargs['trainable'] = trainable
+      kwargs['collections'] = collections
+      if ops.GraphKeys.LOCAL_VARIABLES in collections:
+        with ops.device(self._worker_device):
+          return getter(name, *args, **kwargs)
+      else:
+        return getter(name, *args, **kwargs)
+
 
 
 class ElasticAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1..f55209e 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -28,6 +28,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import adam
 
@@ -78,3 +79,36 @@
                                        lr * m_t_slice / denominator_slice,
                                        use_locking=self._use_locking)
     return control_flow_ops.group(var_update, m_t, v_t)
+
+  def _resource_apply_sparse(self, grad, var, indices):
+    beta1_power, beta2_power = self._get_beta_accumulators()
+    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+    # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+    m = self.get_slot(var, "m")
+    m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+    m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+                                                                indices,
+                                                                m_t_slice)
+
+    # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+    v = self.get_slot(var, "v")
+    v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+                 (1 - beta2_t) * math_ops.square(grad))
+    v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+                                                                indices,
+                                                                v_t_slice)
+
+    # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+    var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+    var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+                                                               indices,
+                                                               var_slice)
+
+    return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index dc4c462..f08ffaa 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -19,14 +19,18 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.contrib.opt.python.training import lazy_adam_optimizer
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -49,9 +53,10 @@
   return param_t, m_t, v_t
 
 
-class AdamOptimizerTest(test.TestCase):
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
 
-  def testSparse(self):
+  @parameterized.parameters([False, True])
+  def testSparse(self, use_resource):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
       with self.cached_session():
         # Initialize variables for numpy implementation.
@@ -61,8 +66,13 @@
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
         grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
 
-        var0 = variables.Variable(var0_np)
-        var1 = variables.Variable(var1_np)
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(var0_np)
+          var1 = resource_variable_ops.ResourceVariable(var1_np)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+
         grads0_np_indices = np.array([0, 1], dtype=np.int32)
         grads0 = ops.IndexedSlices(
             constant_op.constant(grads0_np),
@@ -94,12 +104,17 @@
           self.assertAllCloseAccordingToType(var0_np, var0.eval())
           self.assertAllCloseAccordingToType(var1_np, var1.eval())
 
-  def testSparseDevicePlacement(self):
+  @parameterized.parameters([False, True])
+  def testSparseDevicePlacement(self, use_resource):
     for index_dtype in [dtypes.int32, dtypes.int64]:
       with self.test_session(force_gpu=test.is_gpu_available()):
         # If a GPU is available, tests that all optimizer ops can be placed on
         # it (i.e. they have GPU kernels).
-        var = variables.Variable([[1.0], [2.0]])
+        if use_resource:
+          var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+        else:
+          var = variables.Variable([[1.0], [2.0]])
+
         indices = constant_op.constant([0, 1], dtype=index_dtype)
         gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
         optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
@@ -107,13 +122,21 @@
         variables.global_variables_initializer().run()
         minimize_op.run()
 
-  def testSparseRepeatedIndices(self):
+  @parameterized.parameters([False, True])
+  def testSparseRepeatedIndices(self, use_resource):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
       with self.cached_session():
-        repeated_index_update_var = variables.Variable(
-            [[1.0], [2.0]], dtype=dtype)
-        aggregated_update_var = variables.Variable(
-            [[1.0], [2.0]], dtype=dtype)
+        if use_resource:
+          repeated_index_update_var = resource_variable_ops.ResourceVariable(
+              [[1.0], [2.0]], dtype=dtype)
+          aggregated_update_var = resource_variable_ops.ResourceVariable(
+              [[1.0], [2.0]], dtype=dtype)
+        else:
+          repeated_index_update_var = variables.Variable(
+              [[1.0], [2.0]], dtype=dtype)
+          aggregated_update_var = variables.Variable(
+              [[1.0], [2.0]], dtype=dtype)
+
         grad_repeated_index = ops.IndexedSlices(
             constant_op.constant(
                 [0.1, 0.1], shape=[2, 1], dtype=dtype),
@@ -139,6 +162,204 @@
           self.assertAllClose(aggregated_update_var.eval(),
                               repeated_index_update_var.eval())
 
+  def doTestBasic(self, use_resource=False, use_callable_params=False):
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      with self.session(graph=ops.Graph()):
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        if use_resource:
+          var0 = resource_variable_ops.ResourceVariable(
+              var0_np, name="var0_%d" % i)
+          var1 = resource_variable_ops.ResourceVariable(
+              var1_np, name="var1_%d" % i)
+        else:
+          var0 = variables.Variable(var0_np)
+          var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+
+        learning_rate = lambda: 0.001
+        beta1 = lambda: 0.9
+        beta2 = lambda: 0.999
+        epsilon = lambda: 1e-8
+        if not use_callable_params:
+          learning_rate = learning_rate()
+          beta1 = beta1()
+          beta2 = beta2()
+          epsilon = epsilon()
+
+        opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        opt_variables = opt.variables()
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+        self.assertIsNotNone(beta1_power)
+        self.assertIsNotNone(beta2_power)
+        self.assertIn(beta1_power, opt_variables)
+        self.assertIn(beta2_power, opt_variables)
+
+        if not context.executing_eagerly():
+          with ops.Graph().as_default():
+            # Shouldn't return non-slot variables from other graphs.
+            self.assertEqual(0, len(opt.variables()))
+          self.evaluate(variables.global_variables_initializer())
+          # Fetch params to validate initial values
+          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+          self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Adam
+        for t in range(1, 4):
+          if not context.executing_eagerly():
+            self.evaluate(update)
+          elif t > 1:
+            opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+          self.assertAllCloseAccordingToType(0.9**(t + 1),
+                                             self.evaluate(beta1_power))
+          self.assertAllCloseAccordingToType(0.999**(t + 1),
+                                             self.evaluate(beta2_power))
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+          if use_resource:
+            self.assertEqual("var0_%d/Adam:0" % (i,),
+                             opt.get_slot(var=var0, name="m").name)
+
+  def testBasic(self):
+    with self.test_session():
+      self.doTestBasic(use_resource=False)
+
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
+  def testResourceBasic(self):
+    self.doTestBasic(use_resource=True)
+
+  def testBasicCallableParams(self):
+    with context.eager_mode():
+      self.doTestBasic(use_resource=True, use_callable_params=True)
+
+  def testTensorLearningRate(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = variables.Variable(var0_np)
+        var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+        opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001))
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Run 3 steps of Adam
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+          self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+          update.run()
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+  def testSharing(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      with self.test_session():
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = variables.Variable(var0_np)
+        var1 = variables.Variable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+        opt = lazy_adam_optimizer.LazyAdamOptimizer()
+        update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+        variables.global_variables_initializer().run()
+
+        beta1_power, beta2_power = opt._get_beta_accumulators()
+
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+
+        # Run 3 steps of intertwined Adam1 and Adam2.
+        for t in range(1, 4):
+          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+          self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+          if t % 2 == 0:
+            update1.run()
+          else:
+            update2.run()
+
+          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, var0.eval())
+          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+  def testTwoSessions(self):
+    optimizer = lazy_adam_optimizer.LazyAdamOptimizer()
+
+    with context.eager_mode():
+      var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+      grads0 = constant_op.constant(np.array([0.1, 0.1]))
+      optimizer.apply_gradients([(grads0, var0)])
+
+    g = ops.Graph()
+    with g.as_default():
+      with self.session(graph=g):
+        var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+        grads0 = constant_op.constant(np.array([0.1, 0.1]))
+        optimizer.apply_gradients([(grads0, var0)])
+
+    gg = ops.Graph()
+    with gg.as_default():
+      with self.session(graph=gg):
+        var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+        grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+        # If the optimizer saves any state not keyed by graph the following line
+        # fails.
+        optimizer.apply_gradients([(grads0, var0)])
+
+  def testSlotsUniqueEager(self):
+    with context.eager_mode():
+      v1 = resource_variable_ops.ResourceVariable(1.)
+      v2 = resource_variable_ops.ResourceVariable(1.)
+      opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+      opt.minimize(lambda: v1 + v2)
+      # There should be two non-slot variables, and two unique slot variables
+      # for v1 and v2 respectively.
+      self.assertEqual(6, len(set(opt.variables())))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index b6b10e5..746df77 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -89,7 +89,13 @@
       self._local_2_global[local_var] = global_variable
       return local_var
     else:
-      return getter(name, trainable, collections, *args, **kwargs)
+      kwargs['trainable'] = trainable
+      kwargs['collections'] = collections
+      if ops.GraphKeys.LOCAL_VARIABLES in collections:
+        with ops.device(self._worker_device):
+          return getter(name, *args, **kwargs)
+      else:
+        return getter(name, *args, **kwargs)
 
 
 class ModelAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index 3acd940..b1fc50a 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -80,28 +80,28 @@
         var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
         var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
 
-      with ops.device("/job:worker/task:" + str(worker_id)):
-        if worker_id == 0:
-          grads_0 = constant_op.constant(-1.0)
-          grads_1 = constant_op.constant(-1.0)
-        else:
-          grads_0 = constant_op.constant(-2.0)
-          grads_1 = constant_op.constant(-2.0)
-        sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
-        opt = model_average_optimizer.ModelAverageOptimizer(
-            opt=sgd_opt,
-            num_worker=num_workers,
-            ma_custom_getter=ma_coustom,
-            is_chief=is_chief,
-            interval_steps=steps)
-        train_op = [
-            opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
-                                global_step)
-        ]
-      easgd_hook = opt.make_session_run_hook()
+        with ops.device("/job:worker/task:" + str(worker_id)):
+          if worker_id == 0:
+            grads_0 = constant_op.constant(-1.0)
+            grads_1 = constant_op.constant(-1.0)
+          else:
+            grads_0 = constant_op.constant(-2.0)
+            grads_1 = constant_op.constant(-2.0)
+          sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+          opt = model_average_optimizer.ModelAverageOptimizer(
+              opt=sgd_opt,
+              num_worker=num_workers,
+              ma_custom_getter=ma_coustom,
+              is_chief=is_chief,
+              interval_steps=steps)
+          train_op = [
+              opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+                                  global_step)
+          ]
+      ma_hook = opt.make_session_run_hook()
       # Creates MonitoredSession
       sess = training.MonitoredTrainingSession(
-          workers[worker_id].target, hooks=[easgd_hook])
+          workers[worker_id].target, hooks=[ma_hook])
 
     sessions.append(sess)
     graphs.append(graph)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index c333d1e..25ec475 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -64,18 +64,17 @@
 
   def _create_vars(self, var_list, state):
     for v in var_list:
-      # TODO(isaprykin): Delete colocate_with(v) from other optimizers and
-      # confirm that colocation will happen anyway.
       dtype = v.dtype.base_dtype
       if v.get_shape().is_fully_defined():
         init = init_ops.constant_initializer(self._initial_accumulator_value,
                                              dtype=dtype)
       else:
-        # Use a Tensor instead of initializer if variable does not have static
-        # shape.
-        init_constant = gen_array_ops.fill(
-            array_ops.shape(v), self._initial_accumulator_value)
-        init = math_ops.cast(init_constant, dtype)
+        def init(v=v, dtype=dtype):
+          # Use a Tensor instead of initializer if variable does not have
+          # static shape.
+          init_constant = gen_array_ops.fill(array_ops.shape(v),
+                                             self._initial_accumulator_value)
+          return math_ops.cast(init_constant, dtype)
       state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
                                          "accumulator")
 
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f6ecaba..6af59dc 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -214,7 +214,8 @@
     # with that Tensor cast to that dtype.
     with ops.init_scope():
       self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
-                     for name, (dynamic, value) in hyper.items() if not dynamic}
+                     for name, (dynamic, value) in sorted(hyper.items())
+                     if not dynamic}
     self._slots = {}
     self._non_slot_dict = {}
     # Extra state to help Optimizers implement Checkpointable. Holds information
@@ -231,7 +232,8 @@
     ret._deferred_dependencies = self._deferred_dependencies
     ret._deferred_slot_restorations = self._deferred_slot_restorations
     ret._hyper = {name: {None: _resolve(value, name)}
-                  for name, (dynamic, value) in hyper.items() if dynamic}
+                  for name, (dynamic, value) in sorted(hyper.items())
+                  if dynamic}
     ret._hyper.update(self._hyper)
     ret._non_slot_devices = non_slot_devices
     ret._distribution = distribution
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 31a6fe1..9a19502 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -38,7 +38,7 @@
     desired_shape = numpy.array([6, None])
     output_tensor = input_tensor.reshape((6, 2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       result = periodic_resample(input_tensor, desired_shape).eval()
       self.assertAllEqual(result, output_tensor)
@@ -49,7 +49,7 @@
     desired_shape = numpy.array([5, None])
     output_tensor = input_tensor.reshape((6, 2))[:-1]
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       result = periodic_resample(input_tensor, desired_shape).eval()
       self.assertAllEqual(result, output_tensor)
@@ -63,7 +63,7 @@
                                                            [15]]])
 
     # NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       result = periodic_resample(input_tensor, desired_shape).eval()
       # input_tensor[0, 0, 0] == result[0, 0, 0]
@@ -88,14 +88,14 @@
           [[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
 
     # NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       result = periodic_resample(input_tensor, desired_shape).eval()
       self.assertAllEqual(result, output_tensor)
 
   def testPeriodicResampleErrors(self):
     input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesWithPredicateMatch(
           errors_impl.InvalidArgumentError,
           'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -109,7 +109,7 @@
     desired_shape = numpy.array([4, 4, None])
     result_shape = (4, 4, 1)
     input_shape = (2, 2, 4)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32, shape=input_shape)
       output = periodic_resample(x, desired_shape)
       error = gradient_checker.compute_gradient_error(
@@ -117,7 +117,7 @@
       self.assertLess(error, 1e-4)
 
   def testPeriodicResampleShapeInference(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Case 1: output shape can be fully inferreed.
       x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
       output = periodic_resample(x, [4, 4, None])
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index 95da6d0..0339939 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -23,7 +23,6 @@
 
 from tensorflow.contrib.predictor import predictor
 from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
 from tensorflow.python.client import session
 from tensorflow.python.framework import ops
 from tensorflow.python.saved_model import loader
@@ -68,23 +67,19 @@
   metagraph_def = get_meta_graph_def(export_dir, tags)
 
   try:
-    signature_def = signature_def_utils.get_signature_def_by_key(
-        metagraph_def,
+    signature_def = metagraph_def.signature_def[signature_def_key]
+  except KeyError as e:
+    formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
         signature_def_key)
-  except ValueError as e:
     try:
-      formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
-          signature_def_key)
-      signature_def = signature_def_utils.get_signature_def_by_key(
-          metagraph_def, formatted_key)
-
-      logging.warning('Could not find signature def "%s". '
-                      'Using "%s" instead', signature_def_key, formatted_key)
-    except ValueError:
+      signature_def = metagraph_def.signature_def[formatted_key]
+    except KeyError:
       raise ValueError(
           'Got signature_def_key "{}". Available signatures are {}. '
           'Original error:\n{}'.format(
               signature_def_key, list(metagraph_def.signature_def), e))
+    logging.warning('Could not find signature def "%s". '
+                    'Using "%s" instead', signature_def_key, formatted_key)
   return signature_def
 
 
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 499fec4..c59f667 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@
         ":common",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
         "//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@
         ":common",
         ":graph_matcher",
         ":input_to_ops",
-        "//tensorflow/contrib/graph_editor:graph_editor_py",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@
         ":graph_matcher",
         ":input_to_ops",
         ":quant_ops",
-        "//tensorflow/contrib/graph_editor:graph_editor_py",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e1..b27117d 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@
     return s[len(prefix):]
   else:
     return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+  """Reroute the consumers of the tensor t1 to the tensor t0.
+
+  Args:
+    t0: a tf.Tensor.
+    t1: a tf.Tensor.
+    can_modify: iterable of operations which can be modified. Any operation
+      outside can_modify will be left untouched by this function.
+
+  Returns:
+    The number of individual modifications made by the function.
+  """
+  nb_update_inputs = 0
+  consumers = t1.consumers()
+  if can_modify is not None:
+    consumers = [c for c in consumers if c in can_modify]
+  consumers_indices = {}
+  for c in consumers:
+    consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+  for c in consumers:
+    for i in consumers_indices[c]:
+      c._update_input(i, t0)  # pylint: disable=protected-access
+      nb_update_inputs += 1
+  return nb_update_inputs
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2..2b26302 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@
 
 from tensorflow.contrib.quantize.python import common
 from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@
       _, step_val = sess.run([b, quantization_step_tensor])
       self.assertEqual(step_val, 2)
 
+  def testRerouteTensor(self):
+    a = constant_op.constant(1, name='a')
+    b = constant_op.constant(2, name='b')
+    c = constant_op.constant(3, name='c')
+    d = constant_op.constant(4, name='d')
+
+    add_ac = math_ops.add(a, c)
+    add_ad = math_ops.add(a, d)
+
+    # Ensure that before rerouting the inputs are what we think.
+    self._CheckOpHasInputs(add_ac.op, [a, c])
+    self._CheckOpHasInputs(add_ad.op, [a, d])
+
+    # References to tensor a should be replaced with b for all ops in
+    # can_modify. This means add_ac will be changed but add_ad will not.
+    common.RerouteTensor(b, a, can_modify=[add_ac.op])
+    self._CheckOpHasInputs(add_ac.op, [b, c])
+    self._CheckOpHasInputs(add_ad.op, [a, d])
+
+  def _CheckOpHasInputs(self, op, inputs):
+    for i in inputs:
+      self.assertIn(i, op.inputs)
+
 
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179b..2971b28 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@
 from __future__ import print_function
 
 import re
-from tensorflow.contrib import graph_editor
 from tensorflow.contrib.quantize.python import common
 from tensorflow.contrib.quantize.python import graph_matcher
 from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@
       bias_add_tensor = math_ops.add(
           new_layer_tensor, bias_tensor, name='add_fold')
 
-      nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
-                                                     match.output_tensor)
+      nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+                                                  match.output_tensor)
       if nodes_modified_count == 0:
         raise ValueError('Folding batch norms failed, %s had no outputs.' %
                          match.output_tensor.name)
@@ -370,8 +369,9 @@
         lambda: match.bn_decay_mean_tensor,
         name='freeze_moving_mean')
 
-    graph_editor.reroute_ts(
-        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+    common.RerouteTensor(
+        bn_decay_mean_out,
+        match.bn_decay_mean_tensor,
         can_modify=bn_decay_mean_consumers)
 
     bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@
         lambda: bn_decay_zero,
         lambda: match.bn_decay_var_tensor,
         name='freeze_moving_var')
-    graph_editor.reroute_ts(
-        [bn_decay_var_out], [match.bn_decay_var_tensor],
+    common.RerouteTensor(
+        bn_decay_var_out,
+        match.bn_decay_var_tensor,
         can_modify=bn_decay_var_consumers)
 
     correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@
 
     activation = common.GetEndpointActivationOp(graph, bn)
     if activation:
-      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
-                                                     [original_op.outputs[0]],
-                                                     can_modify=[activation])
+      nodes_modified_count = common.RerouteTensor(
+          folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
       if nodes_modified_count != 1:
         raise ValueError('Unexpected inputs to op: %s' % activation.name)
       continue
@@ -497,9 +497,8 @@
     # operations instead of Relu* above.
     add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
     add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
-    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
-                                                   [original_op.outputs[0]],
-                                                   can_modify=[add_bypass])
+    nodes_modified_count = common.RerouteTensor(
+        folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
     if nodes_modified_count != 1:
       raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
 
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73..e88db0a 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@
 from __future__ import print_function
 
 import re
-from tensorflow.contrib import graph_editor
 from tensorflow.contrib.quantize.python import common
 from tensorflow.contrib.quantize.python import graph_matcher
 from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@
         name=name_prefix + '/delayed_quant')
 
   if consumers:
-    tensors_modified_count = graph_editor.reroute_ts(
-        [quant], [inputs], can_modify=consumers)
+    tensors_modified_count = common.RerouteTensor(
+        quant, inputs, can_modify=consumers)
     # Some operations can have multiple output tensors going to the same
     # consumer. Since consumers is a set, we need to ensure that
     # tensors_modified_count is greater than or equal to the length of the set
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
index 00fbd4f..aea80a5 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -56,7 +56,7 @@
           x_power=state.x_power * theta.x)
       return next_state, []
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       theta = _PolyTheta(x=array_ops.constant(2.0))
       state = _PolyState(
           value=array_ops.constant(0.0),
@@ -142,7 +142,7 @@
 
   def _ParameterizedTestElman(self, seqlen, use_grad):
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       random_seed.set_random_seed(342462)
 
       batch = 3
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 5874245..4e67d80 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -212,6 +212,7 @@
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
+    tags = ["noasan"],
 )
 
 tf_custom_op_library(
@@ -279,7 +280,10 @@
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
-    tags = ["no_oss"],
+    tags = [
+        "no_oss",
+        "noasan",
+    ],
 )
 
 tf_cc_test(
@@ -287,6 +291,7 @@
     size = "small",
     srcs = ["ops/gru_ops_test.cc"],
     data = [":python/ops/_gru_ops.so"],
+    tags = ["noasan"],
     # We must ensure that the dependencies can be dynamically linked since
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),
@@ -306,6 +311,7 @@
     size = "small",
     srcs = ["ops/lstm_ops_test.cc"],
     data = [":python/ops/_lstm_ops.so"],
+    tags = ["noasan"],
     # We must ensure that the dependencies can be dynamically linked since
     # the shared library must be able to use core:framework.
     # linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 15ce9d1..be0306c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@
 class RNNCellTest(test.TestCase):
 
   def testLinear(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(1.0)):
         x = array_ops.zeros([1, 2])
@@ -69,7 +69,7 @@
         self.assertEqual(len(variables_lib.trainable_variables()), 2)
 
   def testBasicRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -89,7 +89,7 @@
         self.assertEqual(res[0].shape, (1, 2))
 
   def testBasicRNNCellNotTrainable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       def not_trainable_getter(getter, *args, **kwargs):
         kwargs["trainable"] = False
@@ -116,7 +116,7 @@
         self.assertEqual(res[0].shape, (1, 2))
 
   def testIndRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -137,7 +137,7 @@
         self.assertEqual(res[0].shape, (1, 2))
 
   def testGRUCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -165,7 +165,7 @@
         self.assertAllClose(res[0], [[0.156736, 0.156736]])
 
   def testIndyGRUCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -193,7 +193,7 @@
         self.assertAllClose(res[0], [[0.155127, 0.157328]])
 
   def testSRUCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -208,7 +208,7 @@
         self.assertAllClose(res[0], [[0.509682, 0.509682]])
 
   def testSRUCellWithDiffSize(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -288,7 +288,7 @@
 
   def testBasicLSTMCellDimension0Error(self):
     """Tests that dimension 0 in both(x and m) shape must be equal."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         num_units = 2
@@ -309,7 +309,7 @@
 
   def testBasicLSTMCellStateSizeError(self):
     """Tests that state_size must be num_units * 2."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         num_units = 2
@@ -329,7 +329,7 @@
               })
 
   def testBasicLSTMCellStateTupleType(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -360,7 +360,7 @@
         self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
 
   def testBasicLSTMCellWithStateTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -459,7 +459,7 @@
           self.assertEqual(len(res), 2)
 
   def testLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 8
       num_proj = 6
       state_size = num_units + num_proj
@@ -494,7 +494,7 @@
               float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
 
   def testLSTMCellVariables(self):
-    with self.test_session():
+    with self.cached_session():
       num_units = 8
       num_proj = 6
       state_size = num_units + num_proj
@@ -517,7 +517,7 @@
                         "root/lstm_cell/projection/kernel")
 
   def testLSTMCellLayerNorm(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 2
       num_proj = 3
       batch_size = 1
@@ -562,22 +562,21 @@
         rnn_cell_impl.DropoutWrapper,
         rnn_cell_impl.ResidualWrapper,
         lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
-      with self.test_session():
-        cell = rnn_cell_impl.BasicRNNCell(1)
-        wrapper = wrapper_type(cell)
-        wrapper(array_ops.ones([1, 1]),
-                state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
-        self.evaluate([v.initializer for v in cell.variables])
-        checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
-        prefix = os.path.join(self.get_temp_dir(), "ckpt")
-        self.evaluate(cell._bias.assign([40.]))
-        save_path = checkpoint.save(prefix)
-        self.evaluate(cell._bias.assign([0.]))
-        checkpoint.restore(save_path).assert_consumed().run_restore_ops()
-        self.assertAllEqual([40.], self.evaluate(cell._bias))
+      cell = rnn_cell_impl.BasicRNNCell(1)
+      wrapper = wrapper_type(cell)
+      wrapper(array_ops.ones([1, 1]),
+              state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+      self.evaluate([v.initializer for v in cell.variables])
+      checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+      prefix = os.path.join(self.get_temp_dir(), "ckpt")
+      self.evaluate(cell._bias.assign([40.]))
+      save_path = checkpoint.save(prefix)
+      self.evaluate(cell._bias.assign([0.]))
+      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+      self.assertAllEqual([40.], self.evaluate(cell._bias))
 
   def testOutputProjectionWrapper(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -594,7 +593,7 @@
         self.assertAllClose(res[0], [[0.231907, 0.231907]])
 
   def testInputProjectionWrapper(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -612,7 +611,7 @@
         self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
 
   def testResidualWrapper(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -638,7 +637,7 @@
         self.assertAllClose(res[2], res[3])
 
   def testResidualWrapperWithSlice(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 5])
@@ -716,7 +715,7 @@
       self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
 
   def testEmbeddingWrapper(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 1], dtype=dtypes.int32)
@@ -735,7 +734,7 @@
         self.assertAllClose(res[0], [[0.17139, 0.17139]])
 
   def testEmbeddingWrapperWithDynamicRnn(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope("root"):
         inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
         input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
@@ -753,7 +752,7 @@
         sess.run(outputs)
 
   def testMultiRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -770,7 +769,7 @@
         self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
 
   def testMultiRNNCellWithStateTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -809,7 +808,7 @@
                           time_steps=None,
                           parallel_iterations=None,
                           **kwargs):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         if batch_size is None and time_steps is None:
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index aa4562b..bf699db 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1906,7 +1906,7 @@
     state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
     out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables_lib.global_variables_initializer())
       sess.run(variables_lib.local_variables_initializer())
 
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e..8d34b9e 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@
   def testBasicRNNFusedWrapper(self):
     """This test checks that using a wrapper for BasicRNN works as expected."""
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=19890212)
       cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@
         self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
 
   def testTimeReversedFusedRNN(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=19890213)
       fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 2df8f0e..6689664 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@
 class RNNCellTest(test.TestCase):
 
   def testCoupledInputForgetGateLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 2
       state_size = num_units * 2
       batch_size = 3
@@ -81,7 +81,7 @@
         self.assertAllClose(res[1], expected_state)
 
   def testTimeFreqLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 8
       state_size = num_units * 2
       batch_size = 3
@@ -120,7 +120,7 @@
               float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
 
   def testGridLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 8
       batch_size = 3
       input_size = 4
@@ -166,7 +166,7 @@
                                   .state_f00_b00_c[i, :]))) > 1e-6)
 
   def testGridLSTMCellWithFrequencyBlocks(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 8
       batch_size = 3
       feature_size = 2
@@ -248,7 +248,7 @@
         ]],
         dtype=np.float32)
     for state_is_tuple in [False, True]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         with variable_scope.variable_scope(
             "state_is_tuple" + str(state_is_tuple),
             initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@
             self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
 
   def testBidirectionGridLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 2
       batch_size = 3
       input_size = 4
@@ -374,7 +374,7 @@
         self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
 
   def testBidirectionGridLSTMCellWithSliceOffset(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 2
       batch_size = 3
       input_size = 4
@@ -487,7 +487,7 @@
     input_size = 4
     for state_is_tuple in [False, True]:
       with ops.Graph().as_default():
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           with variable_scope.variable_scope(
               "state_is_tuple_" + str(state_is_tuple)):
             lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@
     batch_size = 3
     for state_is_tuple in [False, True]:
       with ops.Graph().as_default():
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           with variable_scope.variable_scope(
               "state_is_tuple_" + str(state_is_tuple)):
             lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@
         0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
         0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "nas_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@
         0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
         1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@
         [[0.13752282, 0.13752282], [0.10545051, 0.10545051],
          [0.10074195, 0.10074195]],
         dtype=np.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
         cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@
         [[2.00431061, 2.00431061], [4.00060606, 4.00060606],
          [6.00008249, 6.00008249]],
         dtype=np.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "intersection_rnn_cell_test",
           initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@
       cell(inputs, init_state)
 
   def testPhasedLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_units = 2
       batch_size = 3
       input_size = 4
@@ -874,7 +874,7 @@
         self.assertAllClose(res[1].h, expected_state_h)
 
   def testConv1DLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shape = [2, 1]
       filter_size = [3]
       num_features = 1
@@ -907,7 +907,7 @@
         self.assertAllClose(res[1].h, expected_state_h)
 
   def testConv2DLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shape = [2, 2, 1]
       filter_size = [3, 3]
       num_features = 1
@@ -948,7 +948,7 @@
         self.assertAllClose(res[1].h, expected_state_h)
 
   def testConv3DLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shape = [2, 2, 2, 1]
       filter_size = [3, 3, 3]
       num_features = 1
@@ -999,7 +999,7 @@
         self.assertAllClose(res[1].h, expected_state_h)
 
   def testHighwayWrapper(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "base_cell", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@
 
     # Try with input dimension equal to num_units or not.
     for num_inputs in [num_units, num_units + number_of_groups]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         with variable_scope.variable_scope(
             "root1_%d" % num_inputs,
             initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@
 
     # Try with num_inputs equal to or not equal to num_units.
     for num_inputs in [num_units, num_units + number_of_groups]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         with variable_scope.variable_scope(
             "root2_%d" % num_inputs,
             initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@
     batch_size = 2
     num_units = 4
     number_of_groups = 2
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope(
           "glstm_failure", initializer=init_ops.constant_initializer(0.5)):
         gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@
   # NOTE: all the values in the current test case have been calculated.
 
   def testBasicLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@
 
   def testBasicLSTMCellWithoutNorm(self):
     """Tests that BasicLSTMCell with layer_norm=False."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@
         self.assertAllClose(res[1].h, expected_h, 1e-5)
 
   def testBasicLSTMCellWithStateTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@
 
   def testBasicLSTMCellWithStateTupleLayerNorm(self):
     """The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@
     num_units = 5
     allowed_low = [1, 2, 3]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "other", initializer=init_ops.constant_initializer(1)):
         x = array_ops.zeros([1, 5])
@@ -1479,7 +1479,7 @@
       self.assertAllClose(xla_g, non_xla_g, atol=atol)
 
   def testMultiRNNCellWithStateTuple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@
   def _cell_output(self, cell):
     """Calculates cell output."""
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       init = init_ops.constant_initializer(0.5)
       with variable_scope.variable_scope("root",
                                          initializer=init):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index f74c95f..06c4816 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -97,10 +97,10 @@
 
   The default non-peephole implementation is based on:
 
-    http://www.bioinf.jku.at/publications/older/2604.pdf
+    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
 
-  S. Hochreiter and J. Schmidhuber.
-  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
 
   The peephole implementation is based on:
 
@@ -2448,10 +2448,10 @@
 
   The default non-peephole implementation is based on:
 
-    http://www.bioinf.jku.at/publications/older/2604.pdf
+    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
 
-  S. Hochreiter and J. Schmidhuber.
-  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
 
   The peephole implementation is based on:
 
@@ -2802,9 +2802,11 @@
     Training of Deep Neural Networks
 
     The default LSTM implementation based on:
-    http://www.bioinf.jku.at/publications/older/2604.pdf
-    S. Hochreiter and J. Schmidhuber.
-    "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+      https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+    Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+    "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
 
     The class uses optional peephole connections, optional cell clipping
     and an optional projection layer.
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index b897224..4ca5274 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -78,23 +78,6 @@
     ],
 )
 
-py_test(
-    name = "signature_def_utils_test",
-    size = "small",
-    srcs = ["python/saved_model/signature_def_utils_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":saved_model_py",
-        "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_for_generated_wrappers",
-        "//tensorflow/python/saved_model:signature_constants",
-        "//tensorflow/python/saved_model:signature_def_utils",
-        "//tensorflow/python/saved_model:utils",
-    ],
-)
-
 py_library(
     name = "keras_saved_model",
     srcs = ["python/saved_model/keras_saved_model.py"],
@@ -123,6 +106,7 @@
     size = "medium",
     srcs = ["python/saved_model/keras_saved_model_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["notsan"],
     deps = [
         ":keras_saved_model",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 074dc65..ac95e38 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -25,13 +25,11 @@
 
 # pylint: disable=unused-import,wildcard-import,line-too-long
 from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
-from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
 # pylint: enable=unused-import,wildcard-import,line-too-long
 
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
-    "get_signature_def_by_key",
     "load_keras_model",
     "save_keras_model"]
 
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index 3c616c5..ea4d41d 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -30,6 +30,7 @@
     hdrs = ["signature_def_utils.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/cc/saved_model:signature_constants",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
@@ -42,6 +43,7 @@
     srcs = ["signature_def_utils_test.cc"],
     deps = [
         ":signature_def_utils",
+        "//tensorflow/cc/saved_model:signature_constants",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
index a45908d..e87e497 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -15,6 +15,8 @@
 
 #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
 
+#include "tensorflow/cc/saved_model/signature_constants.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -33,6 +35,79 @@
   *value = &it->second;
   return Status::OK();
 }
+
+// Looks up the TensorInfo for the given key in the given map and verifies that
+// its datatype matches the given correct datatype.
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
+                                 const string& key, DataType correct_dtype) {
+  const TensorInfo* tensor_info;
+  const Status& status = FindInProtobufMap("", map, key, &tensor_info);
+  if (!status.ok()) {
+    return false;
+  }
+  if (tensor_info->dtype() != correct_dtype) {
+    return false;
+  }
+  return true;
+}
+
+bool IsValidPredictSignature(const SignatureDef& signature_def) {
+  if (signature_def.method_name() != kPredictMethodName) {
+    return false;
+  }
+  if (signature_def.inputs().empty()) {
+    return false;
+  }
+  if (signature_def.outputs().empty()) {
+    return false;
+  }
+  return true;
+}
+
+bool IsValidRegressionSignature(const SignatureDef& signature_def) {
+  if (signature_def.method_name() != kRegressMethodName) {
+    return false;
+  }
+  if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
+                                   DT_STRING)) {
+    return false;
+  }
+  if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
+                                   DT_FLOAT)) {
+    return false;
+  }
+  return true;
+}
+
+bool IsValidClassificationSignature(const SignatureDef& signature_def) {
+  if (signature_def.method_name() != kClassifyMethodName) {
+    return false;
+  }
+  if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
+                                   DT_STRING)) {
+    return false;
+  }
+  if (signature_def.outputs().empty()) {
+    return false;
+  }
+  for (auto const& output : signature_def.outputs()) {
+    const string& key = output.first;
+    const TensorInfo& tensor_info = output.second;
+    if (key == kClassifyOutputClasses) {
+      if (tensor_info.dtype() != DT_STRING) {
+        return false;
+      }
+    } else if (key == kClassifyOutputScores) {
+      if (tensor_info.dtype() != DT_FLOAT) {
+        return false;
+      }
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
 }  // namespace
 
 Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
@@ -74,4 +149,10 @@
   return Status::OK();
 }
 
+bool IsValidSignature(const SignatureDef& signature_def) {
+  return IsValidClassificationSignature(signature_def) ||
+         IsValidRegressionSignature(signature_def) ||
+         IsValidPredictSignature(signature_def);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index b732cdd..bb24faa 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -64,6 +64,9 @@
 Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
                                  const string& tensor_info_key, string* name);
 
+// Determine whether a SignatureDef can be served by TensorFlow Serving.
+bool IsValidSignature(const SignatureDef& signature_def);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
index a063e95..c743112 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
 
+#include "tensorflow/cc/saved_model/signature_constants.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
@@ -22,7 +23,7 @@
 
 namespace tensorflow {
 
-class SignatureDefUtilsTest : public ::testing::Test {
+class FindByKeyTest : public ::testing::Test {
  protected:
   MetaGraphDef MakeSampleMetaGraphDef() {
     MetaGraphDef result;
@@ -32,13 +33,23 @@
     return result;
   }
 
+  void SetInputNameForKey(const string& key, const string& name,
+                          SignatureDef* signature_def) {
+    (*signature_def->mutable_inputs())[key].set_name(name);
+  }
+
+  void SetOutputNameForKey(const string& key, const string& name,
+                           SignatureDef* signature_def) {
+    (*signature_def->mutable_outputs())[key].set_name(name);
+  }
+
   SignatureDef MakeSampleSignatureDef() {
     SignatureDef result;
     result.set_method_name(kMethodName);
-    (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
-    (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
-    (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
-    (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+    SetInputNameForKey(kInput1Key, kInput1Name, &result);
+    SetInputNameForKey(kInput2Key, kInput2Name, &result);
+    SetOutputNameForKey(kOutput1Key, kOutput1Name, &result);
+    SetOutputNameForKey(kOutput2Key, kOutput2Name, &result);
     return result;
   }
 
@@ -54,7 +65,7 @@
   const string kOutput2Name = "output_two";
 };
 
-TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+TEST_F(FindByKeyTest, FindSignatureDefByKey) {
   const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
   const SignatureDef* signature_def;
   // Succeeds for an existing signature.
@@ -67,7 +78,7 @@
           .ok());
 }
 
-TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindInputTensorNameByKey) {
   const SignatureDef signature_def = MakeSampleSignatureDef();
   string name;
   // Succeeds for an existing input.
@@ -78,7 +89,7 @@
       FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
 }
 
-TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindOutputTensorNameByKey) {
   const SignatureDef signature_def = MakeSampleSignatureDef();
   string name;
   // Succeeds for an existing output.
@@ -89,4 +100,100 @@
       FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
 }
 
+class IsValidSignatureTest : public ::testing::Test {
+ protected:
+  void SetInputDataTypeForKey(const string& key, DataType dtype) {
+    (*signature_def_.mutable_inputs())[key].set_dtype(dtype);
+  }
+
+  void SetOutputDataTypeForKey(const string& key, DataType dtype) {
+    (*signature_def_.mutable_outputs())[key].set_dtype(dtype);
+  }
+
+  void EraseOutputKey(const string& key) {
+    (*signature_def_.mutable_outputs()).erase(key);
+  }
+
+  void ExpectInvalidSignature() {
+    EXPECT_FALSE(IsValidSignature(signature_def_));
+  }
+
+  void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); }
+
+  SignatureDef signature_def_;
+};
+
+TEST_F(IsValidSignatureTest, IsValidPredictSignature) {
+  signature_def_.set_method_name("not_kPredictMethodName");
+  // Incorrect method name
+  ExpectInvalidSignature();
+
+  signature_def_.set_method_name(kPredictMethodName);
+  // No inputs
+  ExpectInvalidSignature();
+
+  SetInputDataTypeForKey(kPredictInputs, DT_STRING);
+  // No outputs
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey(kPredictOutputs, DT_STRING);
+  ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidRegressionSignature) {
+  signature_def_.set_method_name("not_kRegressMethodName");
+  // Incorrect method name
+  ExpectInvalidSignature();
+
+  signature_def_.set_method_name(kRegressMethodName);
+  // No inputs
+  ExpectInvalidSignature();
+
+  SetInputDataTypeForKey(kRegressInputs, DT_STRING);
+  // No outputs
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey(kRegressOutputs, DT_STRING);
+  // Incorrect data type
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT);
+  ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidClassificationSignature) {
+  signature_def_.set_method_name("not_kClassifyMethodName");
+  // Incorrect method name
+  ExpectInvalidSignature();
+
+  signature_def_.set_method_name(kClassifyMethodName);
+  // No inputs
+  ExpectInvalidSignature();
+
+  SetInputDataTypeForKey(kClassifyInputs, DT_STRING);
+  // No outputs
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey("invalidKey", DT_FLOAT);
+  // Invalid key
+  ExpectInvalidSignature();
+
+  EraseOutputKey("invalidKey");
+  SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT);
+  // Invalid dtype for classes
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING);
+  // Valid without scores
+  ExpectValidSignature();
+
+  SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING);
+  // Invalid dtype for scores
+  ExpectInvalidSignature();
+
+  SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT);
+  // Valid with both classes and scores
+  ExpectValidSignature();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
index e3b76bb..fd3dc1d 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -25,5 +25,4 @@
 
 # pylint: disable=wildcard-import
 from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
 # pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 8a0dbef..12dd72a 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -50,7 +50,7 @@
     return os.path.join(temp_dir, dirname)
 
   def test_saving_sequential_model(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.RepeatVector(3))
@@ -75,7 +75,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_saving_sequential_model_without_compile(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.RepeatVector(3))
@@ -92,7 +92,7 @@
       self.assertAllClose(ref_y, y, atol=1e-05)
 
   def test_saving_functional_model(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.layers.Input(shape=(3,))
       x = keras.layers.Dense(2)(inputs)
       output = keras.layers.Dense(3)(x)
@@ -117,7 +117,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_saving_functional_model_without_compile(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.layers.Input(shape=(3,))
       x = keras.layers.Dense(2)(inputs)
       output = keras.layers.Dense(3)(x)
@@ -138,7 +138,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_saving_with_tf_optimizer(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.Dense(3))
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
deleted file mode 100644
index f521647..0000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SignatureDef utility functions implementation."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-def get_signature_def_by_key(meta_graph_def, signature_def_key):
-  """Utility function to get a SignatureDef protocol buffer by its key.
-
-  Args:
-    meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefMap to
-      look up.
-    signature_def_key: Key of the SignatureDef protocol buffer to find in the
-      SignatureDefMap.
-
-  Returns:
-    A SignatureDef protocol buffer corresponding to the supplied key, if it
-    exists.
-
-  Raises:
-    ValueError: If no entry corresponding to the supplied key is found in the
-    SignatureDefMap of the MetaGraphDef.
-  """
-  if signature_def_key not in meta_graph_def.signature_def:
-    raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." %
-                     signature_def_key)
-  return meta_graph_def.signature_def[signature_def_key]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
deleted file mode 100644
index d2e14f7..0000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for SignatureDef utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils as signature_def_contrib_utils
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.saved_model import utils
-
-
-class SignatureDefUtilsTest(test.TestCase):
-
-  def _add_to_signature_def_map(self, meta_graph_def, signature_def_map=None):
-    if signature_def_map is not None:
-      for key in signature_def_map:
-        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])
-
-  def _check_tensor_info(self, tensor_info_map, map_key, expected_tensor_name):
-    actual_tensor_info = tensor_info_map[map_key]
-    self.assertEqual(expected_tensor_name, actual_tensor_info.name)
-
-  def testGetSignatureDefByKey(self):
-    x = array_ops.placeholder(dtypes.float32, 1, name="x")
-    x_tensor_info = utils.build_tensor_info(x)
-
-    y = array_ops.placeholder(dtypes.float32, name="y")
-    y_tensor_info = utils.build_tensor_info(y)
-
-    foo_signature_def = signature_def_utils.build_signature_def({
-        "foo-input": x_tensor_info
-    }, {"foo-output": y_tensor_info}, "foo-method-name")
-    bar_signature_def = signature_def_utils.build_signature_def({
-        "bar-input": x_tensor_info
-    }, {"bar-output": y_tensor_info}, "bar-method-name")
-    meta_graph_def = meta_graph_pb2.MetaGraphDef()
-    self._add_to_signature_def_map(
-        meta_graph_def, {"foo": foo_signature_def,
-                         "bar": bar_signature_def})
-
-    # Look up a key that does not exist in the SignatureDefMap.
-    missing_key = "missing-key"
-    with self.assertRaisesRegexp(
-        ValueError,
-        "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
-      signature_def_contrib_utils.get_signature_def_by_key(
-          meta_graph_def, missing_key)
-
-    # Look up the key, `foo` which exists in the SignatureDefMap.
-    foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
-        meta_graph_def, "foo")
-    self.assertTrue("foo-method-name", foo_signature_def.method_name)
-
-    # Check inputs in signature def.
-    self.assertEqual(1, len(foo_signature_def.inputs))
-    self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")
-
-    # Check outputs in signature def.
-    self.assertEqual(1, len(foo_signature_def.outputs))
-    self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")
-
-    # Look up the key, `bar` which exists in the SignatureDefMap.
-    bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
-        meta_graph_def, "bar")
-    self.assertTrue("bar-method-name", bar_signature_def.method_name)
-
-    # Check inputs in signature def.
-    self.assertEqual(1, len(bar_signature_def.inputs))
-    self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")
-
-    # Check outputs in signature def.
-    self.assertEqual(1, len(bar_signature_def.outputs))
-    self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
-
-  def testGetSignatureDefByKeyRegression(self):
-    input1 = constant_op.constant("a", name="input-1")
-    output1 = constant_op.constant(7.2, name="output-1")
-
-    meta_graph_def = meta_graph_pb2.MetaGraphDef()
-    self._add_to_signature_def_map(meta_graph_def, {
-        "my_regression":
-            signature_def_utils.regression_signature_def(input1, output1)
-    })
-
-    # Look up the regression signature with the key used while saving.
-    signature_def = signature_def_contrib_utils.get_signature_def_by_key(
-        meta_graph_def, "my_regression")
-
-    # Check the method name to match the constants regression method name.
-    self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
-                     signature_def.method_name)
-
-    # Check inputs in signature def.
-    self.assertEqual(1, len(signature_def.inputs))
-    self._check_tensor_info(signature_def.inputs,
-                            signature_constants.REGRESS_INPUTS, "input-1:0")
-
-    # Check outputs in signature def.
-    self.assertEqual(1, len(signature_def.outputs))
-    self._check_tensor_info(signature_def.outputs,
-                            signature_constants.REGRESS_OUTPUTS, "output-1:0")
-
-  def testGetSignatureDefByKeyClassification(self):
-    input1 = constant_op.constant("a", name="input-1")
-    output1 = constant_op.constant("b", name="output-1")
-    output2 = constant_op.constant(3.0, name="output-2")
-
-    meta_graph_def = meta_graph_pb2.MetaGraphDef()
-    self._add_to_signature_def_map(meta_graph_def, {
-        "my_classification":
-            signature_def_utils.classification_signature_def(
-                input1, output1, output2)
-    })
-
-    # Look up the classification signature def with the key used while saving.
-    signature_def = signature_def_contrib_utils.get_signature_def_by_key(
-        meta_graph_def, "my_classification")
-
-    # Check the method name to match the constants classification method name.
-    self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
-                     signature_def.method_name)
-
-    # Check inputs in signature def.
-    self.assertEqual(1, len(signature_def.inputs))
-    self._check_tensor_info(signature_def.inputs,
-                            signature_constants.CLASSIFY_INPUTS, "input-1:0")
-
-    # Check outputs in signature def.
-    self.assertEqual(2, len(signature_def.outputs))
-    self._check_tensor_info(signature_def.outputs,
-                            signature_constants.CLASSIFY_OUTPUT_CLASSES,
-                            "output-1:0")
-    self._check_tensor_info(signature_def.outputs,
-                            signature_constants.CLASSIFY_OUTPUT_SCORES,
-                            "output-2:0")
-
-  def testPredictionSignatureDef(self):
-    input1 = constant_op.constant("a", name="input-1")
-    input2 = constant_op.constant("b", name="input-2")
-    output1 = constant_op.constant("c", name="output-1")
-    output2 = constant_op.constant("d", name="output-2")
-
-    meta_graph_def = meta_graph_pb2.MetaGraphDef()
-    self._add_to_signature_def_map(meta_graph_def, {
-        "my_prediction":
-            signature_def_utils.predict_signature_def({
-                "input-1": input1,
-                "input-2": input2
-            }, {"output-1": output1,
-                "output-2": output2})
-    })
-
-    # Look up the prediction signature def with the key used while saving.
-    signature_def = signature_def_contrib_utils.get_signature_def_by_key(
-        meta_graph_def, "my_prediction")
-    self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
-                     signature_def.method_name)
-
-    # Check inputs in signature def.
-    self.assertEqual(2, len(signature_def.inputs))
-    self._check_tensor_info(signature_def.inputs, "input-1", "input-1:0")
-    self._check_tensor_info(signature_def.inputs, "input-2", "input-2:0")
-
-    # Check outputs in signature def.
-    self.assertEqual(2, len(signature_def.outputs))
-    self._check_tensor_info(signature_def.outputs, "output-1", "output-1:0")
-    self._check_tensor_info(signature_def.outputs, "output-2", "output-2:0")
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index f5b6b1b..5e28e65 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -248,6 +248,7 @@
     self.vocab_size = 5
     self.end_token = 0
     self.length_penalty_weight = 0.6
+    self.coverage_penalty_weight = 0.0
 
   def test_step(self):
     dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@@ -258,7 +259,8 @@
         lengths=constant_op.constant(
             2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
         finished=array_ops.zeros(
-            [self.batch_size, self.beam_width], dtype=dtypes.bool))
+            [self.batch_size, self.beam_width], dtype=dtypes.bool),
+        accumulated_attention_probs=())
 
     logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                       0.0001)
@@ -281,7 +283,8 @@
         batch_size=ops.convert_to_tensor(self.batch_size),
         beam_width=self.beam_width,
         end_token=self.end_token,
-        length_penalty_weight=self.length_penalty_weight)
+        length_penalty_weight=self.length_penalty_weight,
+        coverage_penalty_weight=self.coverage_penalty_weight)
 
     with self.cached_session() as sess:
       outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -313,7 +316,8 @@
         lengths=ops.convert_to_tensor(
             [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
         finished=ops.convert_to_tensor(
-            [[False, True, False], [False, False, True]], dtype=dtypes.bool))
+            [[False, True, False], [False, False, True]], dtype=dtypes.bool),
+        accumulated_attention_probs=())
 
     logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                       0.0001)
@@ -336,7 +340,8 @@
         batch_size=ops.convert_to_tensor(self.batch_size),
         beam_width=self.beam_width,
         end_token=self.end_token,
-        length_penalty_weight=self.length_penalty_weight)
+        length_penalty_weight=self.length_penalty_weight,
+        coverage_penalty_weight=self.coverage_penalty_weight)
 
     with self.cached_session() as sess:
       outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -372,6 +377,7 @@
     self.vocab_size = 5
     self.end_token = 0
     self.length_penalty_weight = 0.6
+    self.coverage_penalty_weight = 0.0
 
   def test_step(self):
 
@@ -411,7 +417,8 @@
         cell_state=dummy_cell_state,
         log_probs=log_probs,
         lengths=_lengths,
-        finished=_finished)
+        finished=_finished,
+        accumulated_attention_probs=())
 
     logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                       0.0001)
@@ -434,7 +441,8 @@
         batch_size=ops.convert_to_tensor(self.batch_size),
         beam_width=self.beam_width,
         end_token=self.end_token,
-        length_penalty_weight=self.length_penalty_weight)
+        length_penalty_weight=self.length_penalty_weight,
+        coverage_penalty_weight=self.coverage_penalty_weight)
 
     with self.cached_session() as sess:
       outputs_, next_state_, _, _ = sess.run(
@@ -476,7 +484,9 @@
       embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
       cell = rnn_cell.LSTMCell(cell_depth)
       initial_state = cell.zero_state(batch_size, dtypes.float32)
+      coverage_penalty_weight = 0.0
       if has_attention:
+        coverage_penalty_weight = 0.2
         inputs = array_ops.placeholder_with_default(
             np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                 np.float32),
@@ -508,7 +518,8 @@
           initial_state=cell_state,
           beam_width=beam_width,
           output_layer=output_layer,
-          length_penalty_weight=0.0)
+          length_penalty_weight=0.0,
+          coverage_penalty_weight=coverage_penalty_weight)
 
       final_outputs, final_state, final_sequence_lengths = (
           decoder.dynamic_decode(
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 74741a7..605e314 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -21,6 +21,7 @@
 import collections
 import numpy as np
 
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import dtypes
@@ -49,7 +50,8 @@
 
 class BeamSearchDecoderState(
     collections.namedtuple("BeamSearchDecoderState",
-                           ("cell_state", "log_probs", "finished", "lengths"))):
+                           ("cell_state", "log_probs", "finished", "lengths",
+                            "accumulated_attention_probs"))):
   pass
 
 
@@ -260,6 +262,10 @@
     decoder_initial_state = decoder_initial_state.clone(
         cell_state=tiled_encoder_final_state)
     ```
+
+    Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use
+    when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages
+    the translation to cover all inputs.
   """
 
   def __init__(self,
@@ -271,6 +277,7 @@
                beam_width,
                output_layer=None,
                length_penalty_weight=0.0,
+               coverage_penalty_weight=0.0,
                reorder_tensor_arrays=True):
     """Initialize the BeamSearchDecoder.
 
@@ -286,6 +293,8 @@
         `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
         to storing the result or sampling.
       length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+      coverage_penalty_weight: Float weight to penalize the coverage of source
+        sentence. Disabled with 0.0.
       reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
         state will be reordered according to the beam search path. If the
         `TensorArray` can be reordered, the stacked form will be returned.
@@ -326,6 +335,7 @@
     self._batch_size = array_ops.size(start_tokens)
     self._beam_width = beam_width
     self._length_penalty_weight = length_penalty_weight
+    self._coverage_penalty_weight = coverage_penalty_weight
     self._initial_cell_state = nest.map_structure(
         self._maybe_split_batch_beams, initial_state, self._cell.state_size)
     self._start_tokens = array_ops.tile(
@@ -411,13 +421,18 @@
         on_value=ops.convert_to_tensor(0.0, dtype=dtype),
         off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
         dtype=dtype)
+    init_attention_probs = get_attention_probs(
+        self._initial_cell_state, self._coverage_penalty_weight)
+    if init_attention_probs is None:
+      init_attention_probs = ()
 
     initial_state = BeamSearchDecoderState(
         cell_state=self._initial_cell_state,
         log_probs=log_probs,
         finished=finished,
         lengths=array_ops.zeros(
-            [self._batch_size, self._beam_width], dtype=dtypes.int64))
+            [self._batch_size, self._beam_width], dtype=dtypes.int64),
+        accumulated_attention_probs=init_attention_probs)
 
     return (finished, start_inputs, initial_state)
 
@@ -631,6 +646,7 @@
     beam_width = self._beam_width
     end_token = self._end_token
     length_penalty_weight = self._length_penalty_weight
+    coverage_penalty_weight = self._coverage_penalty_weight
 
     with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
       cell_state = state.cell_state
@@ -655,7 +671,8 @@
           batch_size=batch_size,
           beam_width=beam_width,
           end_token=end_token,
-          length_penalty_weight=length_penalty_weight)
+          length_penalty_weight=length_penalty_weight,
+          coverage_penalty_weight=coverage_penalty_weight)
 
       finished = beam_search_state.finished
       sample_ids = beam_search_output.predicted_ids
@@ -667,7 +684,8 @@
 
 
 def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
-                      beam_width, end_token, length_penalty_weight):
+                      beam_width, end_token, length_penalty_weight,
+                      coverage_penalty_weight):
   """Performs a single step of Beam Search Decoding.
 
   Args:
@@ -684,6 +702,8 @@
     beam_width: Python int.  The size of the beams.
     end_token: The int32 end token.
     length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+    coverage_penalty_weight: Float weight to penalize the coverage of source
+      sentence. Disabled with 0.0.
 
   Returns:
     A new beam state.
@@ -693,6 +713,7 @@
   # Calculate the current lengths of the predictions
   prediction_lengths = beam_state.lengths
   previously_finished = beam_state.finished
+  not_finished = math_ops.logical_not(previously_finished)
 
   # Calculate the total log probs for the new hypotheses
   # Final Shape: [batch_size, beam_width, vocab_size]
@@ -708,16 +729,29 @@
       on_value=np.int64(0),
       off_value=np.int64(1),
       dtype=dtypes.int64)
-  add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+  add_mask = math_ops.to_int64(not_finished)
   lengths_to_add *= array_ops.expand_dims(add_mask, 2)
   new_prediction_lengths = (
       lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
 
+  # Calculate the accumulated attention probabilities if coverage penalty is
+  # enabled.
+  accumulated_attention_probs = None
+  attention_probs = get_attention_probs(
+      next_cell_state, coverage_penalty_weight)
+  if attention_probs is not None:
+    attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2)
+    accumulated_attention_probs = (
+        beam_state.accumulated_attention_probs + attention_probs)
+
   # Calculate the scores for each beam
   scores = _get_scores(
       log_probs=total_probs,
       sequence_lengths=new_prediction_lengths,
-      length_penalty_weight=length_penalty_weight)
+      length_penalty_weight=length_penalty_weight,
+      coverage_penalty_weight=coverage_penalty_weight,
+      finished=previously_finished,
+      accumulated_attention_probs=accumulated_attention_probs)
 
   time = ops.convert_to_tensor(time, name="time")
   # During the first time step we only consider the initial beam
@@ -775,6 +809,15 @@
       range_size=beam_width,
       gather_shape=[-1])
   next_prediction_len += lengths_to_add
+  next_accumulated_attention_probs = ()
+  if accumulated_attention_probs is not None:
+    next_accumulated_attention_probs = _tensor_gather_helper(
+        gather_indices=next_beam_ids,
+        gather_from=accumulated_attention_probs,
+        batch_size=batch_size,
+        range_size=beam_width,
+        gather_shape=[batch_size * beam_width, -1],
+        name="next_accumulated_attention_probs")
 
   # Pick out the cell_states according to the next_beam_ids. We use a
   # different gather_shape here because the cell_state tensors, i.e.
@@ -795,7 +838,8 @@
       cell_state=next_cell_state,
       log_probs=next_beam_probs,
       lengths=next_prediction_len,
-      finished=next_finished)
+      finished=next_finished,
+      accumulated_attention_probs=next_accumulated_attention_probs)
 
   output = BeamSearchDecoderOutput(
       scores=next_beam_scores,
@@ -805,7 +849,53 @@
   return output, next_state
 
 
-def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
+def get_attention_probs(next_cell_state, coverage_penalty_weight):
+  """Get attention probabilities from the cell state.
+
+  Args:
+    next_cell_state: The next state from the cell, e.g. an instance of
+      AttentionWrapperState if the cell is attentional.
+    coverage_penalty_weight: Float weight to penalize the coverage of source
+      sentence. Disabled with 0.0.
+
+  Returns:
+    The attention probabilities with shape `[batch_size, beam_width, max_time]`
+    if coverage penalty is enabled. Otherwise, returns None.
+
+  Raises:
+    ValueError: If no cell is attentional but coverage penalty is enabled.
+  """
+  if coverage_penalty_weight == 0.0:
+    return None
+
+  # Attention probabilities of each attention layer. Each with shape
+  # `[batch_size, beam_width, max_time]`.
+  probs_per_attn_layer = []
+  if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
+    probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
+  elif isinstance(next_cell_state, tuple):
+    for state in next_cell_state:
+      if isinstance(state, attention_wrapper.AttentionWrapperState):
+        probs_per_attn_layer.append(attention_probs_from_attn_state(state))
+
+  if not probs_per_attn_layer:
+    raise ValueError(
+        "coverage_penalty_weight must be 0.0 if no cell is attentional.")
+
+  if len(probs_per_attn_layer) == 1:
+    attention_probs = probs_per_attn_layer[0]
+  else:
+    # Calculate the average attention probabilities from all attention layers.
+    attention_probs = [
+        array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
+    attention_probs = array_ops.concat(attention_probs, -1)
+    attention_probs = math_ops.reduce_mean(attention_probs, -1)
+
+  return attention_probs
+
+
+def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
+                coverage_penalty_weight, finished, accumulated_attention_probs):
   """Calculates scores for beam search hypotheses.
 
   Args:
@@ -813,13 +903,78 @@
       `[batch_size, beam_width, vocab_size]`.
     sequence_lengths: The array of sequence lengths.
     length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+    coverage_penalty_weight: Float weight to penalize the coverage of source
+      sentence. Disabled with 0.0.
+    finished: A boolean tensor of shape `[batch_size, beam_width]` that
+      specifies which elements in the beam are finished already.
+    accumulated_attention_probs: Accumulated attention probabilities up to the
+      current time step, with shape `[batch_size, beam_width, max_time]` if
+      coverage_penalty_weight is not 0.0.
 
   Returns:
-    The scores normalized by the length_penalty.
+    The scores normalized by the length_penalty and coverage_penalty.
+
+  Raises:
+    ValueError: accumulated_attention_probs is None when coverage penalty is
+      enabled.
   """
   length_penalty_ = _length_penalty(
       sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
-  return log_probs / length_penalty_
+  scores = log_probs / length_penalty_
+
+  coverage_penalty_weight = ops.convert_to_tensor(
+      coverage_penalty_weight, name="coverage_penalty_weight")
+  if coverage_penalty_weight.shape.ndims != 0:
+    raise ValueError("coverage_penalty_weight should be a scalar, "
+                     "but saw shape: %s" % coverage_penalty_weight.shape)
+
+  if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
+    return scores
+
+  if accumulated_attention_probs is None:
+    raise ValueError(
+        "accumulated_attention_probs can be None only if coverage penalty is "
+        "disabled.")
+
+  # Add source sequence length mask before computing coverage penalty.
+  accumulated_attention_probs = array_ops.where(
+      math_ops.equal(accumulated_attention_probs, 0.0),
+      array_ops.ones_like(accumulated_attention_probs),
+      accumulated_attention_probs)
+
+  # coverage penalty =
+  #     sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
+  coverage_penalty = math_ops.reduce_sum(
+      math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
+  # Apply coverage penalty to finished predictions.
+  coverage_penalty *= math_ops.to_float(finished)
+  weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
+  # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
+  weighted_coverage_penalty = array_ops.expand_dims(
+      weighted_coverage_penalty, 2)
+  return scores + weighted_coverage_penalty
+
+
+def attention_probs_from_attn_state(attention_state):
+  """Calculates the average attention probabilities.
+
+  Args:
+    attention_state: An instance of `AttentionWrapperState`.
+
+  Returns:
+    The attention probabilities in the given AttentionWrapperState.
+    If there're multiple attention mechanisms, return the average value from
+    all attention mechanisms.
+  """
+  # Attention probabilities over time steps, with shape
+  # `[batch_size, beam_width, max_time]`.
+  attention_probs = attention_state.alignments
+  if isinstance(attention_probs, tuple):
+    attention_probs = [
+        array_ops.expand_dims(prob, -1) for prob in attention_probs]
+    attention_probs = array_ops.concat(attention_probs, -1)
+    attention_probs = math_ops.reduce_mean(attention_probs, -1)
+  return attention_probs
 
 
 def _length_penalty(sequence_lengths, penalty_factor):
diff --git a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
index 1bb6fbc..795de6a 100644
--- a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
@@ -88,7 +88,7 @@
     height = 300
     width = 280
 
-    with self.test_session():
+    with self.cached_session():
       test_dataset = _create_tfrecord_dataset(dataset_dir)
       provider = dataset_data_provider.DatasetDataProvider(test_dataset)
       key, image, label = provider.get(['record_key', 'image', 'label'])
@@ -111,7 +111,7 @@
     height = 300
     width = 280
 
-    with self.test_session():
+    with self.cached_session():
       provider = dataset_data_provider.DatasetDataProvider(
           _create_tfrecord_dataset(dataset_dir))
     [image] = provider.get(['image'])
@@ -128,7 +128,7 @@
     dataset_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
                                                        'tfrecord_dataset'))
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         dataset_data_provider.DatasetDataProvider(
             _create_tfrecord_dataset(dataset_dir), record_key='image')
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index ea8cc0f..c457d44 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -39,7 +39,7 @@
     ops.reset_default_graph()
 
   def _verify_all_data_sources_read(self, shared_queue):
-    with self.test_session():
+    with self.cached_session():
       tfrecord_paths = test_utils.create_tfrecord_files(
           self.get_temp_dir(), num_files=3)
 
@@ -76,7 +76,7 @@
     self.assertEquals(count0 + count1 + count2, num_reads)
 
   def _verify_read_up_to_out(self, shared_queue):
-    with self.test_session():
+    with self.cached_session():
       num_files = 3
       num_records_per_file = 7
       tfrecord_paths = test_utils.create_tfrecord_files(
@@ -161,7 +161,7 @@
     ops.reset_default_graph()
 
   def testTFRecordReader(self):
-    with self.test_session():
+    with self.cached_session():
       self._tfrecord_paths = test_utils.create_tfrecord_files(
           self.get_temp_dir(), num_files=3)
 
@@ -188,7 +188,7 @@
     ops.reset_default_graph()
 
   def testOutOfRangeError(self):
-    with self.test_session():
+    with self.cached_session():
       [tfrecord_path] = test_utils.create_tfrecord_files(
           self.get_temp_dir(), num_files=1)
 
@@ -196,7 +196,7 @@
         tfrecord_path, reader_class=io_ops.TFRecordReader)
     init_op = variables.local_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with queues.QueueRunners(sess):
         num_reads = 11
@@ -205,7 +205,7 @@
             sess.run([key, value])
 
   def testTFRecordReader(self):
-    with self.test_session():
+    with self.cached_session():
       [tfrecord_path] = test_utils.create_tfrecord_files(
           self.get_temp_dir(), num_files=1)
 
@@ -213,7 +213,7 @@
         tfrecord_path, reader_class=io_ops.TFRecordReader)
     init_op = variables.local_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with queues.QueueRunners(sess):
         flowers = 0
diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
index 6c3e57c..7caa42d 100644
--- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
@@ -37,7 +37,7 @@
 class PrefetchQueueTest(test.TestCase):
 
   def testOneThread(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       image_size = 32
       num_batches = 5
@@ -74,7 +74,7 @@
         thread.join()
 
   def testMultiThread(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       image_size = 32
       num_batches = 5
@@ -114,7 +114,7 @@
         thread.join()
 
   def testMultipleDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       image_size = 32
       num_batches = 4
@@ -162,7 +162,7 @@
         prefetch_queue.prefetch_queue([variable_tensor])
 
   def testDynamicPad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create 3 tensors of variable but compatible shapes.
       var_shape = [None, 2]
       p1 = constant_op.constant([[1, 2], [3, 4]])
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 826242c..3114949 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -45,7 +45,7 @@
         int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
 
   def _EncodedBytesFeature(self, tf_encoded):
-    with self.test_session():
+    with self.cached_session():
       encoded = tf_encoded.eval()
 
     def BytesList(value):
@@ -133,7 +133,7 @@
     tf_image = self.DecodeExample(serialized_example, item_handler,
                                   image_format)
 
-    with self.test_session():
+    with self.cached_session():
       decoded_image = tf_image.eval()
 
       # We need to recast them here to avoid some issues with uint8.
@@ -265,7 +265,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'labels':
@@ -296,7 +296,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
@@ -319,7 +319,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
@@ -342,7 +342,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -366,7 +366,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'labels':
@@ -390,7 +390,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -423,7 +423,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -468,7 +468,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -505,7 +505,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -536,7 +536,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -567,7 +567,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -598,7 +598,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
       keys_to_features = {
           'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -625,7 +625,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
 
       keys_to_features = {
@@ -657,7 +657,7 @@
 
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
 
       keys_to_features = {
@@ -692,7 +692,7 @@
       image, serialized_example = self.GenerateImage(
           image_format=image_encoding, image_shape=image_shape)
 
-      with self.test_session():
+      with self.cached_session():
 
         def ConditionalDecoding(keys_to_tensors):
           """See base class."""
@@ -759,7 +759,7 @@
             }))
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
 
       keys_to_features = {
@@ -800,7 +800,7 @@
             }))
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
 
       keys_to_features = {
@@ -837,7 +837,7 @@
     image, _ = self.GenerateImage(
         image_format=image_format, image_shape=image_shape)
     tf_encoded = self._Encoder(image, image_format)
-    with self.test_session():
+    with self.cached_session():
       tf_string = tf_encoded.eval()
 
     example = example_pb2.Example(
@@ -852,7 +852,7 @@
             }))
     serialized_example = example.SerializeToString()
 
-    with self.test_session():
+    with self.cached_session():
       serialized_example = array_ops.reshape(serialized_example, shape=[])
 
       decoder = tfexample_decoder.TFExampleDecoder(
@@ -885,7 +885,7 @@
     table = lookup_ops.index_table_from_tensor(
         constant_op.constant(['dog', 'guinea pig', 'cat']))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(lookup_ops.tables_initializer())
 
       serialized_example = array_ops.reshape(serialized_example, shape=[])
@@ -943,7 +943,7 @@
     decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                  items_to_handlers)
     obtained_class_ids_each_example = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(lookup_ops.tables_initializer())
       for example in [example1, example2, example3]:
         serialized_example = array_ops.reshape(
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
index 4707dc2..8fcd7ae 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
@@ -47,7 +47,7 @@
         low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
     tol = 1e-12 if dtype_ == np.float64 else 1e-5
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if use_static_shape_:
         a = constant_op.constant(a_np)
       else:
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
index a736427..2a91009 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
@@ -47,7 +47,7 @@
         low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
     tol = 1e-12 if dtype_ == np.float64 else 1e-6
     max_iter = 20
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if use_static_shape_:
         a = constant_op.constant(a_np)
         rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
index a128284..a0e6eb8 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
@@ -54,7 +54,7 @@
     x_np = np.zeros_like(rhs_np)
     tol = 1e-6 if dtype_ == np.float64 else 1e-3
     max_iter = 20
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if use_static_shape_:
         a = constant_op.constant(a_np)
         rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 5d75346..57b4996 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -33,7 +33,7 @@
       a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
       x_np = np.array([[2.], [-3.]], dtype=dtype)
       y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         if use_static_shape_:
           a = constant_op.constant(a_np, dtype=dtype)
           x = constant_op.constant(x_np, dtype=dtype)
@@ -68,7 +68,7 @@
       a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
       x_np = np.array([[2.], [-3.]], dtype=dtype)
       y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         if use_static_shape_:
           a = constant_op.constant(a_np, dtype=dtype)
           x = constant_op.constant(x_np, dtype=dtype)
@@ -101,7 +101,7 @@
     self._testIdentityOperator(False)
 
   def testL2Norm(self):
-    with self.test_session():
+    with self.cached_session():
       x_np = np.array([[2], [-3.], [5.]])
       x_norm_np = np.linalg.norm(x_np)
       x_normalized_np = x_np / x_norm_np
diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py
index 9a4ad36..b7ce6aa 100644
--- a/tensorflow/contrib/specs/python/specs_test.py
+++ b/tensorflow/contrib/specs/python/specs_test.py
@@ -38,7 +38,7 @@
 class SpecsTest(test.TestCase):
 
   def testSimpleConv(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 18, 19, 5))
       spec = "net = Cr(64, [5, 5])"
       outputs = specs.create_net(spec, inputs)
@@ -53,7 +53,7 @@
   def testUnary(self):
     # This is just a quick and dirty check that these ops exist
     # and work as unary ops.
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(17, 55))
       spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax"
       outputs = specs.create_net(spec, inputs)
@@ -63,7 +63,7 @@
       self.assertEqual(tuple(result.shape), (17, 55))
 
   def testAdd(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(17, 55))
       spec = "net = Fs(10) + Fr(10)"
       outputs = specs.create_net(spec, inputs)
@@ -77,7 +77,7 @@
           "<> variablev2 dot variablev2 biasadd relu add")
 
   def testMpPower(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 64, 64, 5))
       spec = "M2 = Mp([2, 2]); net = M2**3"
       outputs = specs.create_net(spec, inputs)
@@ -90,7 +90,7 @@
           "_ maxpool maxpool maxpool")
 
   def testAbbrevPower(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 64, 64, 5))
       spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3"
       outputs = specs.create_net(spec, inputs)
@@ -106,7 +106,7 @@
           " biasadd relu maxpool")
 
   def testAbbrevPower2(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 64, 64, 5))
       spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);"
       spec += "net = (C3(_0=5) | M2)**3"
@@ -123,7 +123,7 @@
           " maxpool")
 
   def testConc(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(10, 20))
       spec = "net = Conc(1, Fs(20), Fs(10))"
       outputs = specs.create_net(spec, inputs)
@@ -137,7 +137,7 @@
           "<> variablev2 dot variablev2 biasadd sig _ concatv2")
 
   def testImport(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(10, 20))
       spec = ("S = Import('from tensorflow.python.ops" +
               " import math_ops; f = math_ops.sigmoid')")
@@ -150,7 +150,7 @@
       self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig")
 
   def testKeywordRestriction(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(10, 20))
       spec = "import re; net = Conc(1, Fs(20), Fs(10))"
       self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs))
@@ -179,7 +179,7 @@
   # NOTE: this test exercises specs.ops variable handling that is currently broken.
   # TODO: fix the underlying issue and re-enable this test.
   def DISABLED_testVar(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with specs.ops:
         # pylint: disable=undefined-variable
         v = Var("test_var",
@@ -196,7 +196,7 @@
   # NOTE: this test exercises shared-subnet handling that is currently broken.
   # TODO: fix the underlying issue and re-enable this test.
   def DISABLED_testShared(self):
-    with self.test_session():
+    with self.cached_session():
       with specs.ops:
         # pylint: disable=undefined-variable
         f = Shared(Fr(100))
diff --git a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py
index 34ff4bc..b82ba06 100644
--- a/tensorflow/contrib/specs/python/summaries_test.py
+++ b/tensorflow/contrib/specs/python/summaries_test.py
@@ -34,7 +34,7 @@
 class SummariesTest(test.TestCase):
 
   def testStructure(self):
-    with self.test_session():
+    with self.cached_session():
       inputs_shape = (1, 18, 19, 5)
       inputs = constant_op.constant(_rand(*inputs_shape))
       spec = "net = Cr(64, [5, 5])"
@@ -48,7 +48,7 @@
           "_ variablev2 conv variablev2 biasadd relu")
 
   def testStructureFromTensor(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 18, 19, 5))
       spec = "net = Cr(64, [5, 5])"
       outputs = specs.create_net(spec, inputs)
@@ -60,7 +60,7 @@
           "_ variablev2 conv variablev2 biasadd relu")
 
   def testPrint(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 18, 19, 5))
       spec = "net = Cr(64, [5, 5])"
       outputs = specs.create_net(spec, inputs)
@@ -70,7 +70,7 @@
       summaries.tf_spec_print(spec, inputs)
 
   def testSummary(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = constant_op.constant(_rand(1, 18, 19, 5))
       spec = "net = Cr(64, [5, 5])"
       outputs = specs.create_net(spec, inputs)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 652f709..00c855d 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -462,7 +462,10 @@
     size = "small",
     srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
     srcs_version = "PY2AND3",
-    tags = ["no_pip_gpu"],
+    tags = [
+        "no_gpu",
+        "no_pip_gpu",
+    ],
     deps = [
         ":tensor_forest_ops_py",
         "//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index db970de..0042d37 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -134,19 +134,19 @@
           weight_column=weights_name,
           label_dimension=params.num_outputs,
           name=name,
-          loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+          loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
     else:
       if params.num_classes == 2:
         return core_head_lib.binary_classification_head(
             weight_column=weights_name,
             name=name,
-            loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+            loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
       else:
         return core_head_lib.multi_class_head(
             n_classes=params.num_classes,
             weight_column=weights_name,
             name=name,
-            loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+            loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
 
 def get_model_fn(params,
                  graph_builder_class,
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index f80a34e..fe2c91c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -246,7 +246,8 @@
     const Tensor& input_weights = context->input(7);
     const Tensor& leaf_ids_tensor = context->input(8);
 
-    std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
+    std::unique_ptr<TensorDataSet> data_set(
+        new TensorDataSet(input_spec_, random_seed_));
     data_set->set_input_tensors(input_data, sparse_input_indices,
                                 sparse_input_values, sparse_input_shape);
 
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 122a67a..9e8979b 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -19,6 +19,7 @@
     "tf_gen_op_libs",
     "tf_gen_op_wrapper_py",
 )
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -181,7 +182,12 @@
     srcs_version = "PY2AND3",
     deps = [
         ":wrap_conversion",
+        "//tensorflow/python:graph_util",
+        "//tensorflow/python:session",
         "//tensorflow/python:tf_optimizer",
+        "//tensorflow/python/saved_model:builder",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:tag_constants",
     ],
 )
 
@@ -410,6 +416,31 @@
     ],
 )
 
+cuda_py_test(
+    name = "trt_convert_test",
+    srcs = ["python/trt_convert_test.py"],
+    additional_deps = [
+        ":trt_convert_py",
+        ":trt_ops_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:graph_util",
+        "//tensorflow/python/saved_model:builder",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:signature_constants",
+        "//tensorflow/python/saved_model:signature_def_utils",
+        "//tensorflow/python/saved_model:tag_constants",
+        "//tensorflow/python/saved_model:utils",
+        "//tensorflow/python/tools:freeze_graph_lib",
+        "//tensorflow/python/tools:saved_model_utils",
+    ],
+    tags = [
+        "no_cuda_on_cpu_tap",
+        "no_windows",
+        "nomac",
+    ],
+)
+
 cuda_py_tests(
     name = "tf_trt_integration_test",
     srcs = [
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2f..369e73b 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,line-too-long
 import six as _six
+# pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
 from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@
 from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
 from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
 from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
 from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+if _six.PY2:
+  _to_bytes = lambda s: s
+  _to_string = lambda s: s
+else:
+  _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+  _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+  FP32 = "FP32"
+  FP16 = "FP16"
+  INT8 = "INT8"
+
+  @staticmethod
+  def supported_precision_modes():
+    return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+                             max_workspace_size_bytes=2 << 20,
+                             precision_mode=TrtPrecisionMode.FP32,
+                             minimum_segment_size=3,
+                             is_dynamic_op=False,
+                             maximum_cached_engines=1,
+                             cached_engine_batch_sizes=None):
+  """Returns a RewriterConfig proto for TRT transformation.
+
+  Args:
+    max_batch_size: max size for the input batch
+    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+      engine can use at execution time. This corresponds to the 'workspaceSize'
+      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+    minimum_segment_size: the minimum number of nodes required for a subgraph to
+      be replaced by TRTEngineOp.
+    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+      network and engine at run time.
+    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+      If the number of cached engines is already at max but none of them can
+      serve the input, the TRTEngineOp will fall back to run the TF function
+      based on which the TRTEngineOp is created.
+    cached_engine_batch_sizes: a list of batch sizes used to create cached
+      engines, only used when is_dynamic_op is True. The length of the list
+      should not exceed maximum_cached_engines, and the dynamic TRT op will
+      use this list to determine the batch sizes of the cached engines, instead
+      of making the decision on the fly. This is useful when we know the most
+      common batch size(s) the application is going to generate.
+
+  Returns:
+    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+  Raises:
+    TypeError: if cached_engine_batch_sizes is not a list.
+    ValueError: if the precision mode is invalid, or if
+      len(cached_engine_batch_sizes) exceeds maximum_cached_engines.
+  """
+  if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+    raise ValueError(("precision mode '{}' is not supported."
+                      "It should be one of {}").format(
+                          precision_mode,
+                          TrtPrecisionMode.supported_precision_modes))
+
+  rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+  rewriter_cfg.optimizers.extend(["constfold", "layout"])
+  optimizer = rewriter_cfg.custom_optimizers.add()
+  optimizer.name = "TensorRTOptimizer"
+  optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+  optimizer.parameter_map["max_batch_size"].i = max_batch_size
+  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+  optimizer.parameter_map[
+      "max_workspace_size_bytes"].i = max_workspace_size_bytes
+  optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+  optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+  if cached_engine_batch_sizes:
+    if not isinstance(cached_engine_batch_sizes, list):
+      raise TypeError("cached_engine_batch_sizes should be a list.")
+    if len(cached_engine_batch_sizes) > maximum_cached_engines:
+      raise ValueError("cached_engine_batch_sizes should not contain more than "
+                       "maximum_cached_engines items.")
+    optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+        cached_engine_batch_sizes)
+  return rewriter_cfg
 
 
 def create_inference_graph(input_graph_def,
                            outputs,
                            max_batch_size=1,
                            max_workspace_size_bytes=2 << 20,
-                           precision_mode="FP32",
+                           precision_mode=TrtPrecisionMode.FP32,
                            minimum_segment_size=3,
                            is_dynamic_op=False,
                            maximum_cached_engines=1,
-                           cached_engine_batches=None):
+                           cached_engine_batch_sizes=None,
+                           input_saved_model_dir=None,
+                           input_saved_model_tags=None,
+                           output_saved_model_dir=None,
+                           session_config=None):
   """Python wrapper for the TRT transformation.
 
   Args:
-    input_graph_def: GraphDef object containing a model to be transformed.
-    outputs: list of tensors or node names for the model outputs.
-    max_batch_size: max size for the input batch
-    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
-    precision_mode: one of 'FP32', 'FP16' and 'INT8'
+    input_graph_def: a GraphDef object containing a model to be transformed. If
+      set to None, the graph will be read from the SavedModel loaded from
+      input_saved_model_dir.
+    outputs: list of tensors or node names for the model outputs. Only used when
+      input_graph_def is not None.
+    max_batch_size: max size for the input batch.
+    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+      engine can use at execution time. This corresponds to the 'workspaceSize'
+      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
     minimum_segment_size: the minimum number of nodes required for a subgraph to
       be replaced by TRTEngineOp.
     is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
       network and engine at run time.
     maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
-    cached_engine_batches: batch sizes used to pre-create cached engines.
+      If the number of cached engines is already at max but none of them can
+      serve the input, the TRTEngineOp will fall back to run the TF function
+      based on which the TRTEngineOp is created.
+    cached_engine_batch_sizes: a list of batch sizes used to create cached
+      engines, only used when is_dynamic_op is True. The length of the list
+      should not exceed maximum_cached_engines, and the dynamic TRT op will
+      use this list to determine the batch sizes of the cached engines, instead
+      of making the decision on the fly. This is useful when we know the most
+      common batch size(s) the application is going to generate.
+    input_saved_model_dir: the directory to load the SavedModel which contains
+      the input graph to transform. Used only when input_graph_def is None.
+    input_saved_model_tags: list of tags to load the SavedModel.
+    output_saved_model_dir: if not None, construct a SavedModel using the
+      returned GraphDef and save it to the specified directory. This option only
+      works when the input graph is loaded from a SavedModel, i.e. when
+      input_saved_model_dir is specified and input_graph_def is None.
+    session_config: the ConfigProto used to create a Session. If not specified,
+      a default ConfigProto will be used.
 
   Returns:
-    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+    A GraphDef transformed from input_graph_def (or the SavedModel graph def
+    loaded from input_saved_model_dir, if input_graph_def is not present), where
+    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+    function is added for each of the subgraphs.
+
+    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+    subgraph GraphDef, which will be converted to a TRT engine at execution time
+    and the TRT engine will be cached for future usage. A new TRT engine will be
+    created each time when none of the cached engines match the input shapes. If
+    it fails to execute the TRT engine or the number of cached engines reaches
+    maximum_cached_engines, the op will fall back to call the corresponding TF
+    function.
+
+    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+    engine created from the corresponding subgraph. No more engines will be
+    created on the fly, and the op will fall back to call the corresponding TF
+    function when it fails to execute the engine.
 
   Raises:
-    ValueError: if the provided precision mode is invalid.
-    RuntimeError: if the returned status message is malformed.
+    ValueError: if the combination of the parameters is invalid.
+    RuntimeError: if the TensorRT library version is incompatible.
   """
-  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
-  if precision_mode.upper() not in supported_precision_modes:
-    raise ValueError(("precision mode '{}' is not supported."
-                      "It should be one of {}").format(
-                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
-  mode = supported_precision_modes[precision_mode.upper()]
   compiled_version = get_linked_tensorrt_version()
   loaded_version = get_loaded_tensorrt_version()
   version_mismatch = False
@@ -101,61 +225,111 @@
     tf_logging.info("Running against TensorRT version %s" % ".".join(
         [str(x) for x in loaded_version]))
 
-  def py2bytes(inp):
-    return inp
+  if session_config is None:
+    session_config = config_pb2.ConfigProto()
 
-  def py3bytes(inp):
-    return inp.encode("utf-8", errors="surrogateescape")
+  if input_saved_model_tags is None:
+    input_saved_model_tags = [tag_constants.SERVING]
+  saved_model_loader = None
+  grappler_meta_graph_def = None
 
-  def py2string(inp):
-    return inp
+  if input_graph_def is None:
+    # Read from SavedModel and freeze the graph if necessary.
+    if input_saved_model_dir is None:
+      raise ValueError("input_graph_def and input_saved_model_dir cannot be "
+                       "both None")
+    with ops.Graph().as_default():
+      with session.Session(config=session_config) as sess:
+        saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+        input_meta_graph_def = saved_model_loader.load(sess,
+                                                       input_saved_model_tags)
+        output_node_names = set()
 
-  def py3string(inp):
-    return inp.decode("utf-8")
+        def _gather_names(tensor_info):
+          """Get the node names from a TensorInfo."""
+          return set(
+              [tensor_info[key].name.split(":")[0] for key in tensor_info])
 
-  if _six.PY2:
-    to_bytes = py2bytes
-    to_string = py2string
+        # Get input and outputs from all SignatureDef.
+        for key in input_meta_graph_def.signature_def:
+          signature_def = input_meta_graph_def.signature_def[key]
+          output_node_names.update(_gather_names(signature_def.inputs))
+          output_node_names.update(_gather_names(signature_def.outputs))
+
+        # Freeze the variables in the SavedModel graph and copy the frozen
+        # graph over.
+        frozen_graph_def = graph_util.convert_variables_to_constants(
+            sess, sess.graph.as_graph_def(add_shapes=True),
+            list(output_node_names))
+        grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+        grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+        # Copy the collections that are not variables.
+        for key in input_meta_graph_def.collection_def:
+          # TODO(laigd): currently we use the collection key to filter out
+          # collections that depend on variable ops, but this may miss some
+          # other user-defined collections. A better way would be to use
+          # CollectionDef::NodeList for the filtering.
+          if key not in [
+              "variables", "local_variables", "model_variables",
+              "trainable_variables", "train_op", "table_initializer"
+          ]:
+            grappler_meta_graph_def.collection_def[key].CopyFrom(
+                input_meta_graph_def.collection_def[key])
+
+        # Copy other information.
+        grappler_meta_graph_def.meta_info_def.CopyFrom(
+            input_meta_graph_def.meta_info_def)
+        for key in input_meta_graph_def.signature_def:
+          grappler_meta_graph_def.signature_def[key].CopyFrom(
+              input_meta_graph_def.signature_def[key])
+        # TODO(laigd): maybe add back AssetFileDef.
   else:
-    to_bytes = py3bytes
-    to_string = py3string
-
-  # Create MetaGraphDef
-  graph = ops.Graph()
-  with graph.as_default():
-    importer.import_graph_def(input_graph_def, name="")
-  meta_graph = saver.export_meta_graph(
-      graph_def=graph.as_graph_def(), graph=graph)
-  if outputs:
-    output_collection = meta_graph_pb2.CollectionDef()
-    output_list = output_collection.node_list.value
-    for i in outputs:
-      if isinstance(i, ops.Tensor):
-        output_list.append(to_bytes(i.name))
-      else:
-        output_list.append(to_bytes(i))
-    meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+    if output_saved_model_dir is not None:
+      raise ValueError("output_saved_model_dir cannot be set when "
+                       "input_graph_def is set")
+    # Create MetaGraphDef from input graph.
+    graph = ops.Graph()
+    with graph.as_default():
+      importer.import_graph_def(input_graph_def, name="")
+    grappler_meta_graph_def = saver.export_meta_graph(
+        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+    if outputs:
+      output_collection = meta_graph_pb2.CollectionDef()
+      output_list = output_collection.node_list.value
+      for i in outputs:
+        if isinstance(i, ops.Tensor):
+          output_list.append(_to_bytes(i.name))
+        else:
+          output_list.append(_to_bytes(i))
+      # TODO(laigd): use another key as the outputs are really not train_op.
+      grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+          output_collection)
 
   # Create RewriterConfig.
-  rewriter_cfg = rewriter_config_pb2.RewriterConfig()
-  rewriter_cfg.optimizers.extend(["constfold", "layout"])
-  optimizer = rewriter_cfg.custom_optimizers.add()
-  optimizer.name = "TensorRTOptimizer"
-  optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
-  optimizer.parameter_map["max_batch_size"].i = max_batch_size
-  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
-  optimizer.parameter_map[
-      "max_workspace_size_bytes"].i = max_workspace_size_bytes
-  optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
-  optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
-  if cached_engine_batches:
-    if not isinstance(cached_engine_batches, list):
-      raise TypeError("cached_engine_batches should be a list.")
-    optimizer.parameter_map["cached_engine_batches"].list.i.extend(
-        cached_engine_batches)
+  rewriter_cfg = tensorrt_rewriter_config(
+      max_batch_size, max_workspace_size_bytes, precision_mode,
+      minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+      cached_engine_batch_sizes)
 
-  return tf_optimizer.OptimizeGraph(
-      rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+  # Run Grappler.
+  transformed_graph_def = tf_optimizer.OptimizeGraph(
+      rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
+
+  # Optionally write the transformed graphdef as SavedModel.
+  if output_saved_model_dir is not None:
+    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+    with ops.Graph().as_default():
+      importer.import_graph_def(transformed_graph_def, name="")
+      with session.Session(config=session_config) as sess:
+        saved_model_builder.add_meta_graph_and_variables(
+            sess,
+            input_saved_model_tags,
+            signature_def_map=grappler_meta_graph_def.signature_def)
+    # Ignore other meta graphs from the input SavedModel.
+    saved_model_builder.save()
+
+  return transformed_graph_def
 
 
 def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@
   Args:
     calibration_graph_def: the calibration GraphDef object with calibration data
     is_dynamic_op: whether to create dynamic static engines from calibration
+
   Returns:
     New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
   Raises:
     RuntimeError: if the returned status message is malformed.
   """
 
-  def py2string(inp):
-    return inp
-
-  def py3string(inp):
-    return inp.decode("utf-8")
-
-  if _six.PY2:
-    to_string = py2string
-  else:
-    to_string = py3string
   is_calib_graph = False
   for n in calibration_graph_def.node:
     if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@
     return None
   graph_str = calibration_graph_def.SerializeToString()
   out = calib_convert(graph_str, is_dynamic_op)
-  status = to_string(out[0])
+  status = _to_string(out[0])
   output_graph_def_string = out[1]
   del graph_str  # Save some memory
   if len(status) < 2:
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
new file mode 100644
index 0000000..118a668
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.tools import saved_model_utils
+
+
+class TrtConvertTest(test_util.TensorFlowTestCase):
+  """Class to test TensorFlow-TensorRT integration Python API."""
+
+  def testTensorrtRewriterConfig(self):
+    """Test case for trt_convert.tensorrt_rewriter_config()."""
+    rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+        max_batch_size=128,
+        max_workspace_size_bytes=1234,
+        precision_mode="INT8",
+        minimum_segment_size=10,
+        is_dynamic_op=True,
+        maximum_cached_engines=2,
+        cached_engine_batch_sizes=[1, 128])
+    trt_optimizer = None
+    for optimizer in rewriter_cfg.custom_optimizers:
+      if optimizer.name == "TensorRTOptimizer":
+        self.assertTrue(trt_optimizer is None)
+        trt_optimizer = optimizer
+    self.assertTrue(trt_optimizer is not None)
+    for key in [
+        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
+        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
+        "cached_engine_batches"
+    ]:
+      self.assertTrue(key in trt_optimizer.parameter_map)
+    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
+    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
+    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
+    self.assertEqual(1234,
+                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
+    self.assertEqual(
+        trt_convert._to_bytes("INT8"),
+        trt_optimizer.parameter_map["precision_mode"].s)
+    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
+    self.assertEqual(
+        [1, 128],
+        trt_optimizer.parameter_map["cached_engine_batches"].list.i)
+
+  def _GetConfigProto(self):
+    """Get ConfigProto for session creation."""
+    config = config_pb2.ConfigProto(
+        gpu_options=config_pb2.GPUOptions(allow_growth=True))
+    return config
+
+  def _GetGraph(self):
+    """Get the graph for testing."""
+    g = ops.Graph()
+    with g.as_default():
+      with g.device("/GPU:0"):
+        inp = array_ops.placeholder(
+            dtype=dtypes.float32, shape=[None, 1, 1], name="input")
+        var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+        add = inp + var.value()
+        mul = inp * add
+        add = mul + add
+        out = array_ops.identity(add, name="output")
+    return g, var, inp, out
+
+  def _GetGraphDef(self):
+    """Get the graph def for testing."""
+    g, var, _, _ = self._GetGraph()
+    with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+      sess.run(var.initializer)
+      graph_def = graph_util.convert_variables_to_constants(
+          sess, g.as_graph_def(add_shapes=True), ["output"])
+    node_name_to_op = {node.name: node.op for node in graph_def.node}
+    self.assertEqual({
+        "v1": "Const",
+        "v1/read": "Identity",
+        "input": "Placeholder",
+        "add": "Add",
+        "mul": "Mul",
+        "add_1": "Add",
+        "output": "Identity"
+    }, node_name_to_op)
+    return graph_def
+
+  def _WriteInputSavedModel(self, input_saved_model_dir):
+    """Write the saved model as an input for testing."""
+    g, var, inp, out = self._GetGraph()
+    signature_def = signature_def_utils.build_signature_def(
+        inputs={"myinput": utils.build_tensor_info(inp)},
+        outputs={"myoutput": utils.build_tensor_info(out)},
+        method_name=signature_constants.PREDICT_METHOD_NAME)
+    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
+    with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+      sess.run(var.initializer)
+      saved_model_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING],
+          signature_def_map={"mypredict": signature_def})
+    saved_model_builder.save()
+
+  def _TestCreateInferenceGraph(self,
+                                input_saved_model_dir=None,
+                                output_saved_model_dir=None):
+    """General method to test trt_convert.create_inference_graph()."""
+    input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
+    output_graph_def = trt_convert.create_inference_graph(
+        input_graph_def, ["output"],
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+    graph_defs_to_verify = [output_graph_def]
+    if output_saved_model_dir is not None:
+      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+          output_saved_model_dir, tag_constants.SERVING).graph_def
+      self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+      graph_defs_to_verify.append(saved_model_graph_def)
+
+    for graph_def in graph_defs_to_verify:
+      node_name_to_op = {node.name: node.op for node in graph_def.node}
+      self.assertEqual({
+          "input": "Placeholder",
+          "my_trt_op_0": "TRTEngineOp",
+          "output": "Identity"
+      }, node_name_to_op)
+
+  def testCreateInferenceGraph_BasicConversion(self):
+    """Test case for trt_convert.create_inference_graph()."""
+    if not trt_convert.is_tensorrt_enabled():
+      return
+
+    # Use GraphDef as input.
+    self._TestCreateInferenceGraph()
+
+    # Use SavedModel as input.
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir1")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    self._TestCreateInferenceGraph(input_saved_model_dir,
+                                   output_saved_model_dir)
+
+  def _TestRun(self, sess, batch_size, expect_engine_is_run):
+    trt_convert.clear_test_values("")
+    result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+    self.assertAllEqual([[[4.0]]] * batch_size, result)
+    execute_engine_test_value = ("done" if expect_engine_is_run else "")
+    execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
+    self.assertEqual(execute_engine_test_value,
+                     trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine"))
+    self.assertEqual(
+        execute_native_segment_test_value,
+        trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment"))
+
+  def testCreateInferenceGraph_MinimumSegmentSize(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    output_graph_def = trt_convert.create_inference_graph(
+        self._GetGraphDef(), ["output"],
+        minimum_segment_size=5,
+        is_dynamic_op=False)
+    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
+    self.assertEqual({
+        "v1/read": "Const",
+        "input": "Placeholder",
+        "add": "Add",
+        "mul": "Mul",
+        "add_1": "Add",
+        "output": "Identity"
+    }, node_name_to_op)
+
+  def testCreateInferenceGraph_DynamicOp(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    trt_convert.enable_test_value()
+
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    output_graph_def = trt_convert.create_inference_graph(
+        None,
+        None,
+        is_dynamic_op=True,
+        maximum_cached_engines=2,
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+
+    # Test the output GraphDef.
+    with ops.Graph().as_default():
+      importer.import_graph_def(output_graph_def, name="")
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        # Run with batch size 1, a new engine is created and cached.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, a new engine is created and cached.
+        self._TestRun(sess, 2, True)
+        # Run with batch size 3, since the number of cached engines has reached
+        # the max, it should fall back to TF function.
+        self._TestRun(sess, 3, False)
+
+    # Test the output SavedModel
+    with ops.Graph().as_default():
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+        # Run with batch size 1, a new engine is created and cached.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, a new engine is created and cached.
+        self._TestRun(sess, 2, True)
+        # Run with batch size 3, since the number of cached engines has reached
+        # the max, it should fall back to TF function.
+        self._TestRun(sess, 3, False)
+
+  def testCreateInferenceGraph_StaticOp(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    trt_convert.enable_test_value()
+
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    output_graph_def = trt_convert.create_inference_graph(
+        None,
+        None,
+        max_batch_size=1,
+        is_dynamic_op=False,
+        maximum_cached_engines=2,  # This is noop, added just for testing.
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+
+    # Test the output GraphDef.
+    with ops.Graph().as_default():
+      importer.import_graph_def(output_graph_def, name="")
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        # Run with batch size 1, the default engine embedded in the graphdef
+        # will be used.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, which exceeds the max_batch_size, so it should
+        # back to TF function.
+        self._TestRun(sess, 2, False)
+
+    # Test the output SavedModel
+    with ops.Graph().as_default():
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+        # Run with batch size 1, the default engine embedded in the graphdef
+        # will be used.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, which exceeds the max_batch_size, so it should
+        # back to TF function.
+        self._TestRun(sess, 2, False)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8b..d26f260 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   o1 = run_graph(orig_graph, dummy_input)
   o2 = run_graph(trt_graph, dummy_input)
   o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   int8_calib_gdef = trt.create_inference_graph(
       input_graph_def=orig_graph,
       outputs=["output"],
@@ -216,7 +216,7 @@
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   o4 = run_graph(fp16_graph, dummy_input)
   _ = run_calibration(int8_calib_gdef, dummy_input)
   int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21c..fc647e4 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
 # pylint: enable=unused-import
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@
 ConversionParams = namedtuple("ConversionParams", [
     "max_batch_size", "max_workspace_size_bytes", "precision_mode",
     "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "cached_engine_batches"
+    "cached_engine_batch_sizes"
 ])
 
 PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -139,7 +138,7 @@
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
-        cached_engine_batches=None)
+        cached_engine_batch_sizes=None)
 
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -201,23 +200,12 @@
   def _GetConfigProto(self, run_params, graph_state):
     """Get config proto based on specific settings."""
     if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
-      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
-      rewriter_cfg.optimizers.extend(["constfold", "layout"])
-      custom_op = rewriter_cfg.custom_optimizers.add()
-      custom_op.name = "TensorRTOptimizer"
       trt_params = self.GetConversionParams(run_params)
-      custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
-      custom_op.parameter_map["max_workspace_size_bytes"].i = (
-          trt_params.max_workspace_size_bytes)
-      custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
-      custom_op.parameter_map["minimum_segment_size"].i = (
-          trt_params.minimum_segment_size)
-      custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
-      custom_op.parameter_map["maximum_cached_engines"].i = (
-          trt_params.maximum_cached_engines)
-      if trt_params.cached_engine_batches:
-        custom_op.parameter_map["cached_engine_batches"].list.i.extend(
-            trt_params.cached_engine_batches)
+      rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+          trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+          trt_params.precision_mode, trt_params.minimum_segment_size,
+          trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+          trt_params.cached_engine_batch_sizes)
 
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
@@ -308,7 +296,7 @@
         minimum_segment_size=trt_params.minimum_segment_size,
         is_dynamic_op=trt_params.is_dynamic_op,
         maximum_cached_engines=trt_params.maximum_cached_engines,
-        cached_engine_batches=trt_params.cached_engine_batches)
+        cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
 
   def _WriteGraph(self, run_params, gdef, graph_state):
     if graph_state == GraphState.ORIGINAL:
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index d808945..1d27fff 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -264,10 +264,10 @@
     elif (not isinstance(periodicities, list) and
           not isinstance(periodicities, tuple)):
       periodicities = [periodicities]
-    self._periods = [int(p) for p in periodicities]
-    for p in self._periods:
+    self._periodicities = [int(p) for p in periodicities]
+    for p in self._periodicities:
       assert p > 0
-    assert len(self._periods) or self.input_window_size
+    assert len(self._periodicities) or self.input_window_size
     assert output_window_size > 0
 
   def initialize_graph(self, input_statistics=None):
@@ -364,9 +364,9 @@
     input_feature_size = 0
     output_window_features = []
     output_feature_size = 0
-    if self._periods:
+    if self._periodicities:
       _, time_features = self._compute_time_features(times)
-      num_time_features = self._buckets * len(self._periods)
+      num_time_features = self._buckets * len(self._periodicities)
       time_features = array_ops.reshape(
           time_features,
           [batch_size,
@@ -849,12 +849,12 @@
   def _compute_time_features(self, time):
     """Compute some features on the time value."""
     batch_size = array_ops.shape(time)[0]
-    num_periods = len(self._periods)
+    num_periods = len(self._periodicities)
     # Reshape to 3D.
     periods = constant_op.constant(
-        self._periods, shape=[1, 1, num_periods, 1], dtype=time.dtype)
+        self._periodicities, shape=[1, 1, num_periods, 1], dtype=time.dtype)
     time = array_ops.reshape(time, [batch_size, -1, 1, 1])
-    window_offset = time / self._periods
+    window_offset = time / self._periodicities
     # Cast to appropriate type and scale to [0, 1) range
     mod = (math_ops.cast(time % periods, self.dtype) * self._buckets /
            math_ops.cast(periods, self.dtype))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 461fe22..83260fc 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -216,6 +216,15 @@
           exogenous_feature_columns=exogenous_feature_columns)
     self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
 
+  def test_structural_ensemble_numpy_input(self):
+    numpy_data = {"times": numpy.arange(50),
+                  "values": numpy.random.normal(size=[50])}
+    estimators.StructuralEnsembleRegressor(
+        num_features=1, periodicities=[], model_dir=self.get_temp_dir(),
+        config=_SeedRunConfig()).train(
+            input_pipeline.WholeDatasetInputFn(
+                input_pipeline.NumpyReader(numpy_data)),
+            steps=1)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index e65e7b7..647455ae 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -122,7 +122,7 @@
           metric[1] for metric in outputs.eval_metric_ops.values()]
       loss_mean, loss_update = metrics.mean(outputs.loss)
       metric_update_ops.append(loss_update)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         coordinator = coordinator_lib.Coordinator()
         queue_runner_impl.start_queue_runners(sess, coord=coordinator)
         variables.local_variables_initializer().run()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
index 703537a..f92148b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
@@ -88,7 +88,7 @@
         window_size=window_size, batch_size=batch_size)
     result, _ = input_fn()
     init_op = variables.local_variables_initializer()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
       session.run(init_op)
@@ -261,7 +261,7 @@
   def _whole_dataset_input_fn_test_template(
       self, time_series_reader, num_features, num_samples):
     result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       session.run(variables.local_variables_initializer())
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -340,7 +340,7 @@
         window_size=window_size)
     features, _ = input_fn()
     init_op = variables.local_variables_initializer()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
       session.run(init_op)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 9b593fe..03da2b8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -896,8 +896,8 @@
           statistics.total_observation_count,
           math_ops.cast(
               gen_math_ops.round(
-                  math_ops.cast(auxiliary_variables.max_time_seen -
-                                statistics.start_time + 1, self._dtype) /
+                  math_ops.cast(max_time_seen_assign -
+                                start_time_update + 1, self._dtype) /
                   inter_observation_duration_estimate), dtypes.int64))
       per_chunk_stat_updates = control_flow_ops.group(
           overall_feature_mean_update, overall_feature_var_update,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index 02d2524..c0de42b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -55,7 +55,7 @@
       running_sum = running_sum + current_contribution
       # pylint: enable=g-no-augmented-assignment
       transition_power = numpy.dot(transition, transition_power)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(result,
                           math_utils.power_sums_tensor(
                               array_size, transition, addition).eval())
@@ -66,7 +66,7 @@
     result = []
     for i in range(powers.shape[0]):
       result.append(numpy.linalg.matrix_power(matrix, powers[i]))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(result,
                           math_utils.matrix_to_powers(matrix, powers).eval(),
                           rtol=1e-5,
@@ -78,7 +78,7 @@
     result = []
     for i in range(batch.shape[0]):
       result.append(numpy.linalg.matrix_power(batch[i], powers[i]))
-    with self.test_session():
+    with self.cached_session():
       # TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be
       # made slightly more stable?
       self.assertAllClose(result,
@@ -91,7 +91,7 @@
     left_transpose = numpy.transpose(left, [0, 2, 1])
     right = numpy.random.normal(size=[2, 3]).astype(numpy.float32)
     expected_result = numpy.dot(left, right)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(expected_result,
                           math_utils.batch_times_matrix(
                               left, right).eval())
@@ -114,7 +114,7 @@
     right_transpose = numpy.transpose(right, [0, 2, 1])
     expected_result = numpy.transpose(numpy.dot(right_transpose, left.T),
                                       [0, 2, 1])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(expected_result,
                           math_utils.matrix_times_batch(
                               left, right).eval())
@@ -132,7 +132,7 @@
                               adj_x=True, adj_y=True).eval())
 
   def test_make_diagonal_undefined_shapes(self):
-    with self.test_session():
+    with self.cached_session():
       completely_undefined = array_ops.placeholder(dtype=dtypes.float32)
       partly_undefined = array_ops.placeholder(
           shape=[None, None], dtype=dtypes.float32)
@@ -152,7 +152,7 @@
                                  [5., 6.]]}))
 
   def test_make_diagonal_mostly_defined_shapes(self):
-    with self.test_session():
+    with self.cached_session():
       mostly_defined = array_ops.placeholder(
           shape=[None, 2], dtype=dtypes.float32)
       blocked = math_utils.block_diagonal([[[2.]],
@@ -192,7 +192,7 @@
 
   def _test_make_toeplitz_matrix(self, inputs, output_expected):
     output_tf = math_utils.make_toeplitz_matrix(inputs)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output_tf_np = sess.run(output_tf)
     self.assertAllClose(output_tf_np, output_expected)
 
@@ -201,13 +201,13 @@
 
   def test_zero_size_matrix(self):
     raw = numpy.zeros([0, 0])
-    with self.test_session():
+    with self.cached_session():
       constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval()
     self.assertEqual((0, 0), constructed.shape)
 
   def test_sign_magnitude_positive_definite(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         matrix_tensor = math_utils.sign_magnitude_positive_definite(
             raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype),
             off_diagonal_scale=constant_op.constant(-1., dtype=dtype),
@@ -230,7 +230,8 @@
         name="test_lookup")
     def stack_tensor(base_tensor):
       return array_ops.stack([base_tensor + 1, base_tensor + 2])
-    with self.test_session() as session:
+
+    with self.cached_session() as session:
       ((float_output, double_output), int_output) = session.run(
           hash_table.lookup([2, 1, 0]))
       def expected_output_before_insert(base_tensor):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index cfd31cc..a049dbe7 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -29,7 +29,7 @@
   def test_parameter_switching(self):
     parameter = array_ops.constant(5)
     overridden_parameter = array_ops.constant(3)
-    with self.test_session():
+    with self.cached_session():
       getter = model_utils.parameter_switch({overridden_parameter: 4})
       self.assertEqual(5, getter(parameter))
       self.assertEqual(4, getter(overridden_parameter))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
index 5f7e3da..42ba6e1 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
@@ -127,7 +127,7 @@
     chainer.initialize_graph(model=stub_model)
     model_outputs = chainer.define_loss(
         model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       variables.global_variables_initializer().run()
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -178,7 +178,7 @@
     result_model_outputs = chainer.define_loss(
         model=stub_model, features=result_input_fn()[0],
         mode=estimator_lib.ModeKeys.TRAIN)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       variables.global_variables_initializer().run()
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -221,7 +221,7 @@
     chainer.initialize_graph(model=stub_model)
     model_outputs = chainer.define_loss(
         model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       variables.global_variables_initializer().run()
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
index 53d7340..a77c507 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
@@ -61,7 +61,7 @@
       expected_state = [[[80.], [20.]],
                         [1., 6.],
                         [-1, -2]]
-      with self.test_session():
+      with self.cached_session():
         for interpolated, expected in zip(interpolated_state, expected_state):
           self.assertAllClose(expected, interpolated.eval())
         self.assertGreater(0., updated_outputs["anomaly_score"][0].eval())
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
index 57f29f3..f636126 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
@@ -98,7 +98,7 @@
         observation_model=observation_model,
         predicted_observations=(observed_mean, observed_var),
         observation_noise=observation_noise_covariance)
-    with self.test_session() as session:
+    with self.cached_session() as session:
       evaled_state = numpy.array([[1., 1., 1., 1.]])
       evaled_state_var = numpy.eye(4)[None]
       for i in range(500):
@@ -136,7 +136,7 @@
 
   def test_observed_from_state(self):
     """Compare observation mean and noise to hand-computed values."""
-    with self.test_session():
+    with self.cached_session():
       state = constant_op.constant([[2., 1.]])
       state_var = constant_op.constant([[[4., 0.], [0., 3.]]])
       observed_mean, observed_var = self.kalman_filter.observed_from_state(
@@ -171,7 +171,7 @@
             observation_model=observation_model,
             predicted_observations=predicted_observations,
             observation_noise=observation_noise))
-    with self.test_session() as session:
+    with self.cached_session() as session:
       evaled_state, evaled_state_var = session.run([state, state_var])
       for _ in range(300):
         evaled_state, evaled_state_var = session.run(
@@ -231,7 +231,7 @@
 
   def test_predict_state_mean(self):
     """Compare state mean transitions with simple hand-computed values."""
-    with self.test_session():
+    with self.cached_session():
       state = constant_op.constant([[4., 2.]])
       state = self.kalman_filter.predict_state_mean(
           state, self.transition_fn([1]))
@@ -245,7 +245,7 @@
 
   def test_predict_state_var(self):
     """Compare a variance transition with simple hand-computed values."""
-    with self.test_session():
+    with self.cached_session():
       state_var = constant_op.constant([[[1., 0.], [0., 2.]]])
       state_var = self.kalman_filter.predict_state_var(
           state_var, self.transition_fn([1]), self.power_sum_fn([1]))
@@ -259,7 +259,7 @@
     Tests that correct values have high probability and incorrect values
     have low probability when there is low uncertainty.
     """
-    with self.test_session():
+    with self.cached_session():
       state = constant_op.constant([[4., 2.]])
       state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]])
       observation = constant_op.constant([[
@@ -289,7 +289,7 @@
       self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99))
 
   def test_predict_n_ahead_mean(self):
-    with self.test_session():
+    with self.cached_session():
       original_state = constant_op.constant([[4., 2.]])
       n = 5
       iterative_state = original_state
@@ -304,7 +304,7 @@
             self.transition_fn([1]))
 
   def test_predict_n_ahead_var(self):
-    with self.test_session():
+    with self.cached_session():
       original_var = constant_op.constant([[[2., 3.], [4., 5.]]])
       n = 5
       iterative_var = original_var
@@ -330,7 +330,7 @@
     Tests that correct values have high probability and incorrect values
     have low probability when there is low uncertainty.
     """
-    with self.test_session():
+    with self.cached_session():
       state = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]])
       state_var = constant_op.constant(3 * [[[0.0001, 0.], [0., 0.0001]]])
       observation = constant_op.constant([
@@ -378,7 +378,7 @@
       self.assertLess(third_log_prob.sum(), numpy.log(0.01))
 
   def test_predict_n_ahead_mean(self):
-    with self.test_session():
+    with self.cached_session():
       kf = kalman_filter.KalmanFilter()
       transition_fn, _ = _powers_and_sums_from_transition_matrix(
           state_transition=STATE_TRANSITION,
@@ -396,7 +396,7 @@
       self.assertAllClose(state2.eval()[2], batch_eval[2])
 
   def test_predict_n_ahead_var(self):
-    with self.test_session():
+    with self.cached_session():
       kf = kalman_filter.KalmanFilter()
       transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix(
           state_transition=STATE_TRANSITION,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index c2eaa78..80126ac 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -96,7 +96,7 @@
           },
           mode=estimator_lib.ModeKeys.TRAIN)
       initializer = variables.global_variables_initializer()
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run([initializer])
         outputs.loss.eval()
 
@@ -114,7 +114,7 @@
           },
           mode=estimator_lib.ModeKeys.TRAIN)
       initializer = variables.global_variables_initializer()
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run([initializer])
         outputs.loss.eval()
 
@@ -144,7 +144,7 @@
         state=math_utils.replicate_state(
             start_state=random_model.get_start_state(),
             batch_size=array_ops.shape(times)[0]))
-    with self.test_session() as session:
+    with self.cached_session() as session:
       variables.global_variables_initializer().run()
       coordinator = coordinator_lib.Coordinator()
       queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -250,7 +250,7 @@
       self.assertAllClose(combined_value, split_predict[prediction_key])
 
   def _equivalent_to_single_model_test_template(self, model_generator):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       random_model = RandomStateSpaceModel(
           state_dimension=5,
           state_noise_dimension=4,
@@ -374,7 +374,7 @@
               math_utils.replicate_state(
                   start_state=random_model.get_start_state(), batch_size=1)
       })
-      with self.test_session():
+      with self.cached_session():
         variables.global_variables_initializer().run()
         predicted_mean = prediction_dict["mean"].eval()
         predicted_covariance = prediction_dict["covariance"].eval()
@@ -404,7 +404,7 @@
           feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
           feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state
       })
-      with self.test_session():
+      with self.cached_session():
         variables.global_variables_initializer().run()
         predicted_mean = predictions["mean"].eval()
         predicted_covariance = predictions["covariance"].eval()
@@ -428,7 +428,7 @@
             state=[
                 array_ops.ones(shape=[1, 5]), original_covariance[None], [0]
             ])
-        with self.test_session() as session:
+        with self.cached_session() as session:
           variables.global_variables_initializer().run()
           evaled_new_covariance, evaled_original_covariance = session.run(
               [new_covariance[0], original_covariance])
@@ -454,7 +454,7 @@
                 -array_ops.ones(shape=[1, 5], dtype=dtype),
                 original_covariance[None], [0]
             ])
-        with self.test_session() as session:
+        with self.cached_session() as session:
           variables.global_variables_initializer().run()
           evaled_new_covariance, evaled_original_covariance = session.run(
               [new_covariance[0], original_covariance])
@@ -519,7 +519,7 @@
         model=stub_model, data=data, true_parameters=true_params)
 
   def test_exact_posterior_recovery_no_transition_noise(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       stub_model, data, true_params = self._get_single_model()
       input_fn = input_pipeline.WholeDatasetInputFn(
           input_pipeline.NumpyReader(data))
@@ -559,7 +559,7 @@
           posterior_times)
 
   def test_chained_exact_posterior_recovery_no_transition_noise(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       stub_model, data, true_params = self._get_single_model()
       chunk_size = 10
       input_fn = test_utils.AllWindowInputFn(
@@ -748,7 +748,7 @@
         },
         mode=estimator_lib.ModeKeys.TRAIN)
     initializer = variables.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([initializer])
       outputs.loss.eval()
 
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
index 84885d5..e8875f4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
@@ -46,7 +46,7 @@
         },
         mode=estimator_lib.ModeKeys.TRAIN)
     initializer = variables.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([initializer])
       outputs.loss.eval()
 
@@ -65,7 +65,7 @@
         },
         mode=estimator_lib.ModeKeys.TRAIN)
     initializer = variables.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([initializer])
       outputs.loss.eval()
 
@@ -85,7 +85,7 @@
             TrainEvalFeatures.VALUES: constant_op.constant([[[1.], [2.]]])},
         mode=estimator_lib.ModeKeys.TRAIN)
     initializer = variables.global_variables_initializer()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([initializer])
       outputs.loss.eval()
 
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 537d94b..3c0456d 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -33,6 +33,7 @@
 @@shard
 @@batch_parallel
 @@rewrite
+@@outside_compilation
 
 @@CrossShardOptimizer
 
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 9ee5ecb..ea8e0e0 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -18,6 +18,89 @@
 #include "tensorflow/core/framework/shape_inference.h"
 
 namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("AllToAll")
+    .Input("input: T")
+    .Input("group_assignment: int32")
+    .Output("output: T")
+    .Attr("T: {bfloat16, float}")
+    .Attr("concat_dimension: int")
+    .Attr("split_dimension: int")
+    .Attr("split_count: int")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle input = c->input(0);
+      int64 rank;
+      if (c->RankKnown(input)) {
+        rank = c->Rank(input);
+      } else {
+        return errors::InvalidArgument("input's rank is unknown.");
+      }
+      int concat_dimension;
+      int split_dimension;
+
+      TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
+
+      if (concat_dimension < 0 || concat_dimension >= rank) {
+        return errors::InvalidArgument("concat_dimension ", concat_dimension,
+                                       " is out of range of input rank ", rank);
+      }
+
+      TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension));
+      if (split_dimension < 0 || split_dimension >= rank) {
+        return errors::InvalidArgument("split_dimension ", split_dimension,
+                                       " is out of range of input rank ", rank);
+      }
+
+      std::vector<DimensionHandle> dims;
+      dims.resize(rank);
+
+      for (int32 i = 0; i < rank; ++i) {
+        int64 in_idx = i;
+        if (i == concat_dimension) {
+          in_idx = split_dimension;
+        } else if (i == split_dimension) {
+          in_idx = concat_dimension;
+        }
+
+        dims[i] = c->Dim(input, in_idx);
+      }
+
+      c->set_output(0, c->MakeShape(dims));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+An Op to exchange data across TPU replicas. On each replica, the input is
+split into `split_count` blocks along `split_dimension`, and the blocks are
+sent to the other replicas given group_assignment. After receiving
+`split_count` - 1 blocks from other replicas, we concatenate the blocks along
+`concat_dimension` as the output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+
+input: The local input to the all-to-all operation.
+group_assignment: An int32 tensor with shape
+  [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+  replica ids in the ith subgroup.
+concat_dimension: The dimension number to concatenate.
+split_dimension: The dimension number to split.
+split_count: The number of splits; this number must equal the sub-group
+  size (group_assignment.get_shape()[1]).
+output: The exchanged result.
+T: The type of elements to be exchanged.
+)doc");
 
 REGISTER_OP("CrossReplicaSum")
     .Input("input: T")
@@ -26,10 +109,8 @@
     .Attr("T: {bfloat16, float}")
     .SetShapeFn(shape_inference::UnchangedShape)
     .Doc(R"doc(
-An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input. If group_assignment is empty, the output of
-each is the sum of all the inputs, otherwise the output of each is the sum of
-the inputs belonging to the same group.
+An Op to sum inputs across replicated TPU instances. Each instance supplies its
+own input.
 
 For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
 Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index 98cc31f..b4b06a4 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -142,9 +142,8 @@
     TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
                                                response.encoded_trace(), os));
   }
-  if (response.has_op_profile() &&
-      (response.op_profile().has_by_program_structure() ||
-       response.op_profile().has_by_category())) {
+  if (response.has_op_profile() && (response.op_profile().has_by_program() ||
+                                    response.op_profile().has_by_category())) {
     TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
                                                    response.op_profile(), os));
   }
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index feb177a..68cf510 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -4,12 +4,14 @@
 
 // Profile is the top-level data that summarizes a program.
 message Profile {
+  reserved 2;
+  reserved "by_program_structure";
+  reserved 3;
+  reserved "per_program";
   // Root of a profile broken down by instruction category.
   Node by_category = 1;
-  // Root of a profile broken down by program structure.
-  Node by_program_structure = 2;
-  // Per program profile, indexed by hlo module name of the program.
-  map<string, Node> per_program = 3;
+  // Root of a profile broken down by program.
+  Node by_program = 4;
 }
 
 // An entry in the profile tree. (An instruction, or set of instructions).
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 2b13343..f88dc51 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -79,12 +79,15 @@
   // The step duration in picoseconds.
   optional uint64 duration_ps = 2;
   // The infeed duration in picoseconds.
-  // Can turn into a map if we want a variable number of ops.
   optional uint64 infeed_duration_ps = 3;
+  // The outfeed duration in picoseconds.
+  optional uint64 host_outfeed_ps = 8;
   // The start time of this step in picoseconds.
   optional uint64 begin_ps = 4;
   // The waiting time within this step in picoseconds.
   optional uint64 wait_duration_ps = 5;
+  // The unit b outfeed duration in picoseconds.
+  optional uint64 unit_b_outfeed_ps = 9;
   // The time spent on cross-replica-sum in picoseconds.
   optional uint64 crs_duration_ps = 6;
   // Percentage of unit b time spent on infeed.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index bf807af6..fc13205 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -9,8 +9,8 @@
   google.protobuf.FloatValue upper = 2;  // +inf if not set
 }
 
-// Get the learning rate from a <yet to be determined> source that can change
-// dynamically.
+// Get the learning rate from the parameters of the SendTPUEmbeddingGradients
+// op.
 message DynamicLearningRate {
 }
 
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 3ed571a..d92a065 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -38,6 +38,62 @@
   _tpu_ops = loader.load_op_library(
       resource_loader.get_path_to_datafile("_tpu_ops.so"))
 
+  def _create_default_group_assignment():
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      logging.warning(
+          "cross_replica_sum should be used within a tpu_shard_context, but "
+          "got unset number_of_shards. Assuming 1.")
+      num_shards = 1
+    group_assignment = [list(range(num_shards))]
+    return group_assignment
+
+  def all_to_all(x,
+                 concat_dimension,
+                 split_dimension,
+                 split_count,
+                 group_assignment=None,
+                 name=None):
+    """Exchange data across TPU replicas.
+
+    Args:
+      x: The local tensor.
+      concat_dimension: The dimension number to concatenate.
+      split_dimension: The dimension number to split.
+      split_count: The number of splits; this number must equal the sub-group
+        size (group_assignment.get_shape()[1]).
+      group_assignment: Optional 2d int32 lists with shape [num_groups,
+        num_replicas_per_group]. `group_assignment[i]` represents the replica
+        ids in the ith subgroup.
+      name: Optional op name.
+
+    Returns:
+      A `Tensor` which is concatenated by data from different replicas.
+    """
+    if group_assignment is None:
+      group_assignment = _create_default_group_assignment()
+    return gen_tpu_ops.all_to_all(
+        x,
+        group_assignment,
+        concat_dimension=concat_dimension,
+        split_dimension=split_dimension,
+        split_count=split_count,
+        name=name)
+
+  @ops.RegisterGradient("AllToAll")
+  def _all_to_all_grad(op, grad):
+    # The gradient of an all-to-all is also an all-to-all, but with the
+    # split_dimension and concat_dimension swapped.
+    # The gradient with respect to group_assignment is None.
+    return [
+        gen_tpu_ops.all_to_all(
+            grad,
+            op.inputs[1],
+            concat_dimension=op.get_attr("split_dimension"),
+            split_dimension=op.get_attr("concat_dimension"),
+            split_count=op.get_attr("split_count")), None
+    ]
+
   def cross_replica_sum(x, group_assignment=None, name=None):
     """Sum the input tensor accorss replicas according to group_assignment.
 
@@ -52,13 +108,7 @@
       A `Tensor` which is summed across replicas.
     """
     if group_assignment is None:
-      num_shards = tpu_function.get_tpu_context().number_of_shards
-      if num_shards is None:
-        logging.warning(
-            "cross_replica_sum should be used within a tpu_shard_context, but "
-            "got unset number_of_shards. Assuming 1.")
-        num_shards = 1
-      group_assignment = [list(range(num_shards))]
+      group_assignment = _create_default_group_assignment()
 
     return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
 
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index ff88508..d8c3872 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -170,11 +170,41 @@
     worker_re = re.compile('/job:([^/]+)')
     for device in metadata.devices:
       if 'TPU:0' in device.name:
-        self.worker_name = worker_re.search(device.name).group(1)
+        self._worker_name = worker_re.search(device.name).group(1)
         break
 
+  def _make_assignment_for_model(self, cpu_model):
+    """Makes a `TPUAssignment` for the passed in `cpu_model`."""
+    num_cores = self._num_cores
+    if num_cores > 1 and cpu_model.stateful:
+      logging.warning(
+          'Model replication does not currently support stateful models.  '
+          'Degrading to a single core.')
+      num_cores = 1
+
+    return TPUAssignment(
+        worker_name=self._worker_name, num_cores=num_cores)
+
+
+class TPUAssignment(object):
+  """This is object holding TPU resources assignment for the concrete model.
+
+  `TPUDistributionStrategy` is responsible for creating the instance of
+  `TPUAssignment`, so it can dynamically adjust the `num_cores` to use based on
+  the model and input batch sizes.
+  """
+
+  def __init__(self, worker_name, num_cores):
+    self._worker_name = worker_name
+    self._num_cores = num_cores
+
+  @property
+  def worker_name(self):
+    return self._worker_name
+
   @property
   def num_towers(self):
+    # TODO(xiejw): Support automatically assign num_cores based on inputs.
     return self._num_cores
 
 
@@ -228,6 +258,8 @@
     return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
 
   def set_weights(self, weights):
+    # TODO(power): Figure out whether we really need this given there is no
+    # caller for this API yet.
     self._opt.set_weights()
 
   def get_weights(self):
@@ -252,9 +284,9 @@
 
 def _replicated_optimizer(opt):
   """Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
-  if tpu_function.get_tpu_context().number_of_shards == 1:
-    return opt
-
+  # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
+  # single core.  This ensures Keras properly tracks and initializes optimizer
+  # variables.
   if isinstance(opt, keras_optimizers.TFOptimizer):
     return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
   else:
@@ -495,8 +527,8 @@
           infeed_dict[tensor] = value
       return infeed_dict
 
-  def __init__(self, distribution_strategy):
-    self._strategy = distribution_strategy
+  def __init__(self, tpu_assignment):
+    self._tpu_assignment = tpu_assignment
 
   def _split_tensors(self, inputs):
     """Split input data across shards.
@@ -509,16 +541,16 @@
     Returns:
       List of lists containing the input to feed to each TPU shard.
     """
-    if self._strategy.num_towers == 1:
+    if self._tpu_assignment.num_towers == 1:
       return [inputs]
 
     batch_size = inputs[0].shape[0]
-    assert batch_size % self._strategy.num_towers == 0, (
-        'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
-        (batch_size, self._strategy.num_towers))
-    shard_size = batch_size // self._strategy.num_towers
+    assert batch_size % self._tpu_assignment.num_towers == 0, (
+        'batch_size must be divisible by the number of TPU cores in use (%s '
+        'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
+    shard_size = batch_size // self._tpu_assignment.num_towers
     input_list = []
-    for index in range(self._strategy.num_towers):
+    for index in range(self._tpu_assignment.num_towers):
       shard_inputs = [
           x[index * shard_size:(index + 1) * shard_size] for x in inputs
       ]
@@ -533,8 +565,9 @@
     infeed_op = []
     shard_infeed_tensors = []
 
-    for shard_id in range(self._strategy.num_towers):
-      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+    for shard_id in range(self._tpu_assignment.num_towers):
+      with ops.device(
+          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
         infeed_tensors = []
         with ops.device('/device:TPU:%d' % shard_id):
           for spec in input_specs:
@@ -573,30 +606,31 @@
       # TODO(saeta): Verify tpu_model_op is as expected!
       return {}
 
-  def __init__(self, dataset, distribution_strategy, tpu_session):
+  # pylint: disable=redefined-outer-name
+  def __init__(self, dataset, tpu_assignment, tpu_session):
     """Constructs a TPUDatasetInfeedManager.
 
     Must be called within a `KerasTPUModel.tpu_session` context!
 
     Args:
       dataset: A `tf.data.Dataset` to infeed.
-      distribution_strategy: The `TPUDistributionStrategy` used to configure the
+      tpu_assignment: The `TPUAssignment` used to configure the
         Keras TPU model.
       tpu_session: The `tf.Session` object used for running the TPU model.
     """
     self._verify_dataset_shape(dataset)
     self._dataset = dataset
-    self._strategy = distribution_strategy
+    self._tpu_assignment = tpu_assignment
     dummy_x_shape = dataset.output_shapes[0].as_list()
-    dummy_x_shape[0] *= distribution_strategy.num_towers
+    dummy_x_shape[0] *= tpu_assignment.num_towers
     dummy_y_shape = dataset.output_shapes[1].as_list()
-    dummy_y_shape[0] *= distribution_strategy.num_towers
+    dummy_y_shape[0] *= tpu_assignment.num_towers
     self._iterator = dataset.make_initializable_iterator()
     tpu_session.run(self._iterator.initializer)
 
     self._get_next_ops = []
     ctrl_deps = []
-    for i in range(distribution_strategy.num_towers):
+    for i in range(tpu_assignment.num_towers):
       with ops.control_dependencies(ctrl_deps):  # Ensure deterministic
         # TODO(saeta): Ensure correct placement!
         get_next_op = self._iterator.get_next()
@@ -676,10 +710,11 @@
 
   def build_infeed_from_input_specs(self, input_specs, execution_mode):
     shard_infeed_tensors = self._get_next_ops
-    assert len(shard_infeed_tensors) == self._strategy.num_towers
+    assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
     infeed_ops = []
-    for shard_id in range(self._strategy.num_towers):
-      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+    for shard_id in range(self._tpu_assignment.num_towers):
+      with ops.device(
+          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
         infeed_ops.append(
             tpu_ops.infeed_enqueue_tuple(
                 shard_infeed_tensors[shard_id],
@@ -702,10 +737,10 @@
   instead of being injected as `feed_dict` items or fetches.
   """
 
-  def __init__(self, model, execution_mode, strategy):
+  def __init__(self, model, execution_mode, tpu_assignment):
     self.model = model
     self.execution_mode = execution_mode
-    self._strategy = strategy
+    self._tpu_assignment = tpu_assignment
     self._compilation_cache = {}
     self._cloned_model = None
 
@@ -757,7 +792,8 @@
       # Clone our CPU model, running within the TPU device context.
       with TPURewriteContext(tpu_input_map):
         with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
-          with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
+          with keras_tpu_variables.replicated_scope(
+              self._tpu_assignment.num_towers):
             self._cloned_model = models.clone_model(self.model)
 
       # Create a copy of the optimizer for this graph.
@@ -827,7 +863,7 @@
     # `execute op` replicates `_model_fn` `num_replicas` times, with each shard
     # running on a different logical core.
     compile_op, execute_op = tpu.split_compile_and_replicate(
-        _model_fn, inputs=[[]] * self._strategy.num_towers)
+        _model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
 
     # Generate CPU side operations to enqueue features/labels and dequeue
     # outputs from the model call.
@@ -835,8 +871,9 @@
         input_specs, self.execution_mode)
     # Build output ops.
     outfeed_op = []
-    for shard_id in range(self._strategy.num_towers):
-      with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+    for shard_id in range(self._tpu_assignment.num_towers):
+      with ops.device(
+          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
         outfeed_op.extend(
             tpu_ops.outfeed_dequeue_tuple(
                 dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -886,7 +923,7 @@
     for x, mgr in self.model._numpy_to_infeed_manager_list:
       if inputs[0] is x:
         return mgr
-    return TPUNumpyInfeedManager(self.model._strategy)
+    return TPUNumpyInfeedManager(self.model._tpu_assignment)
 
   def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
     """Looks up the corresponding `TPUModelOp` for a given `input_specs`.
@@ -958,7 +995,7 @@
       outputs = [[]] * len(self._outfeed_spec)
       outputs_per_replica = len(self._outfeed_spec)
 
-      for i in range(self._strategy.num_towers):
+      for i in range(self._tpu_assignment.num_towers):
         output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
                                        outputs_per_replica]
         for j in range(outputs_per_replica):
@@ -967,7 +1004,7 @@
       return [np.concatenate(group) for group in outputs]
     else:
       return outfeed_outputs[:len(outfeed_outputs) //
-                             self._strategy.num_towers]
+                             self._tpu_assignment.num_towers]
 
   def __call__(self, inputs):
     """__call__ executes the function on the computational hardware.
@@ -1119,11 +1156,11 @@
     self.predict_function = None
     self.test_function = None
     self.train_function = None
-    self._strategy = strategy
 
-    cluster_resolver = self._strategy._tpu_cluster_resolver
+    cluster_resolver = strategy._tpu_cluster_resolver
     self._tpu_name_or_address = cluster_resolver.get_master()
     self._cpu_model = cpu_model
+    self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
     self._tpu_model = None
     self._tpu_weights_initialized = False
 
@@ -1146,7 +1183,7 @@
     return {
         'cpu_model': self._cpu_model,
         'tpu_name_or_address': self._tpu_name_or_address,
-        'strategy': self._strategy,
+        'tpu_assignment': self._tpu_assignment,
     }
 
   def compile(self,
@@ -1207,7 +1244,7 @@
           '/keras')
     if callable(x):
       with self.tpu_session() as sess,\
-          ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
+          ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
         dataset = x()
         if steps_per_epoch is None:
           raise ValueError('When using tf.data as input to a model, you '
@@ -1215,7 +1252,8 @@
         if y is not None:
           raise ValueError('When using tf.data as input to a model, y must be '
                            'None')
-        infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+        infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+                                                 sess)
         # Use dummy numpy inputs for the rest of Keras' shape checking. We
         # intercept them when building the model.
         x = infeed_manager.dummy_x
@@ -1236,7 +1274,8 @@
         if validation_steps is None:
           raise ValueError('When using tf.data as validation for a model, you '
                            'should specify the validation_steps argument.')
-        infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+        infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+                                                 sess)
         # Use dummy numpy inputs for the rest of Keras' shape checking. We
         # intercept them when building the model.
         val_x = infeed_manager.dummy_x
@@ -1313,7 +1352,8 @@
         if y is not None:
           raise ValueError('When using tf.data as input to a model, y must be '
                            'None')
-        infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+        infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+                                                 sess)
         # Use dummy numpy inputs for the rest of Keras' shape checking. We
         # intercept them when building the model.
         x = infeed_manager.dummy_x
@@ -1382,7 +1422,7 @@
         y,
         sample_weights,
         batch_size)
-    self._pipeline_fit_loop(
+    return self._pipeline_fit_loop(
         x,
         y,
         sample_weights=sample_weights,
@@ -1619,7 +1659,7 @@
                       'make sure your paths are correct and you have '
                       'permissions to read the files. Skipping validation')
 
-    for step_index in range(steps_per_epoch - 1):
+    for step_index in range(steps_per_epoch):
       batch_logs = {'batch': step_index, 'size': 1}
       callbacks.on_batch_begin(step_index, batch_logs)
       try:
@@ -1740,20 +1780,24 @@
   def _make_train_function(self):
     if not self.train_function:
       self.train_function = TPUFunction(
-          self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
+          self,
+          model_fn_lib.ModeKeys.TRAIN,
+          tpu_assignment=self._tpu_assignment)
 
     return self.train_function
 
   def _make_test_function(self):
     if not self.test_function:
       self.test_function = TPUFunction(
-          self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
+          self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
     return self.test_function
 
   def _make_predict_function(self):
     if not self.predict_function:
       self.predict_function = TPUFunction(
-          self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
+          self,
+          model_fn_lib.ModeKeys.PREDICT,
+          tpu_assignment=self._tpu_assignment)
     return self.predict_function
 
   def _initialize_weights(self, cloned_model):
@@ -1825,6 +1869,7 @@
     self._session.close()
 
 
+# pylint: disable=bad-continuation
 def _validate_shapes(model):
   """Validate that all layers in `model` have constant shape."""
   for layer in model.layers:
@@ -1852,10 +1897,13 @@
 Input shape: %(input_shape)s
 Output shape: %(output_shape)s
   """ % {
-      'layer': layer,
-      'input_shape': layer.input_shape,
-      'output_shape': layer.output_shape
-      })
+          'layer': layer,
+          'input_shape': layer.input_shape,
+          'output_shape': layer.output_shape
+          })
+
+
+# pylint: enable=bad-continuation
 
 
 @experimental
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index a423aea..170977d 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -30,7 +30,6 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import tf_logging as logging
 
 
 @contextlib.contextmanager
@@ -258,7 +257,6 @@
       collections = [ops.GraphKeys.GLOBAL_VARIABLES]
     kwargs["collections"] = []
 
-    logging.info("Constructing replicated variable %s", name)
     variables = []
     index = {}
     for i in range(num_replicas):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 1e21cc5..0f9f7cd 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -652,13 +652,31 @@
       # TODO(phawkins): consider removing this code. It will
       # be less confusing to clients if they knowingly choose to use resource
       # variables.
+      # Partitioned variables is not supported (b/112311320).
+      def custom_getter(getter, name, *args, **kwargs):
+        """Variables on TPU have a few restrictions."""
+        partitioner = kwargs["partitioner"]
+        if partitioner is not None:
+          kwargs["partitioner"] = None
+          logging.warning(
+              "Partitioned variables are not supported on TPU. Got "
+              "`partitioner` that is {} for variable {}. "
+              "Setting `partitioner` to `None`."
+              .format(partitioner, name))
+        return getter(name, *args, **kwargs)
+
       vscope = variable_scope.get_variable_scope()
+
       saved_use_resource = vscope.use_resource
+      saved_custom_getter = vscope.custom_getter
+
       vscope.set_use_resource(True)
+      vscope.set_custom_getter(custom_getter)
 
       outputs = computation(*computation_inputs)
 
       vscope.set_use_resource(saved_use_resource)
+      vscope.set_custom_getter(saved_custom_getter)
 
     # If the computation returns `None`, make it an empty tuple.
     if outputs is None:
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ad3dce1..d4951b1 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -63,7 +63,7 @@
   }
   CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
   RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
-  string key(std::move(parsed.FullKey().ToString()));
+  string key(parsed.FullKey());
   string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
 
   Device* dst_dev;
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5c314f3..8f32bc2 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -168,6 +168,7 @@
     "example/example.proto",
     "example/feature.proto",
     "framework/allocation_description.proto",
+    "framework/api_def.proto",
     "framework/attr_value.proto",
     "framework/cost_graph.proto",
     "framework/device_attributes.proto",
@@ -177,9 +178,9 @@
     "framework/iterator.proto",
     "framework/kernel_def.proto",
     "framework/log_memory.proto",
+    "framework/model.proto",
     "framework/node_def.proto",
     "framework/op_def.proto",
-    "framework/api_def.proto",
     "framework/reader_base.proto",
     "framework/remote_fused_graph_execute_info.proto",
     "framework/resource_handle.proto",
@@ -299,6 +300,7 @@
     name = "platform_base_hdrs",
     srcs = [
         "platform/byte_order.h",
+        "platform/cord.h",
         "platform/env_time.h",
         "platform/logging.h",
         "platform/macros.h",
@@ -695,15 +697,32 @@
     visibility = ["//visibility:public"],
     deps = [
         ":lib_internal",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
 )
 
 cc_library(
+    name = "feature_util",
+    srcs = ["example/feature_util.cc"],
+    hdrs = [
+        "example/feature_util.h",
+        "platform/types.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":core_stringpiece",
+        ":platform_protobuf",
+        ":protos_all_cc",
+    ],
+)
+
+cc_library(
     name = "abi",
     srcs = ["platform/abi.cc"],
     hdrs = ["platform/abi.h"],
+    deps = [":platform_base"],
 )
 
 cc_library(
@@ -823,6 +842,7 @@
         "framework/log_memory.h",
         "framework/lookup_interface.h",
         "framework/memory_types.h",
+        "framework/model.h",
         "framework/node_def_builder.h",
         "framework/node_def_util.h",
         "framework/numeric_op.h",
@@ -858,7 +878,6 @@
         "util/bcast.h",
         "util/cuda_kernel_helper.h",
         "util/device_name_utils.h",
-        "util/env_var.h",
         "util/events_writer.h",
         "util/example_proto_fast_parsing.h",
         "util/example_proto_helper.h",
@@ -1338,6 +1357,7 @@
         "//tensorflow/core/kernels:mkl_relu_op",
         "//tensorflow/core/kernels:mkl_reshape_op",
         "//tensorflow/core/kernels:mkl_softmax_op",
+        "//tensorflow/core/kernels:mkl_transpose_op",
         "//tensorflow/core/kernels:mkl_tfconv_op",
         "//tensorflow/core/kernels:mkl_aggregate_ops",
     ]) + if_cuda([
@@ -2039,6 +2059,7 @@
     "platform/snappy.h",
     "platform/tensor_coding.h",
     "platform/tracing.h",
+    "util/env_var.h",
 ]
 
 # Replicated for lib_internal and lib_internal_impl.
@@ -2078,6 +2099,7 @@
             "platform/*.cc",
             "platform/profile_utils/**/*.cc",
             "framework/resource_handle.cc",
+            "util/env_var.cc",
         ],
         exclude = [
             "**/*test*",
@@ -2433,7 +2455,6 @@
     "framework/unique_tensor_references.h",
     "framework/variant.h",
     "util/command_line_flags.h",
-    "util/env_var.h",
     "util/equal_graph_def.h",
     "util/presized_cuckoo_map.h",
     "util/tensor_slice_set.h",
@@ -2509,6 +2530,7 @@
             "util/memmapped_file_system_writer.*",
             "util/stats_calculator.*",
             "util/version_info.cc",
+            "util/env_var.cc",
         ],
     ) + select({
         "//tensorflow:windows": [],
@@ -3220,7 +3242,6 @@
         "lib/gtl/edit_distance_test.cc",
         "lib/gtl/flatmap_test.cc",
         "lib/gtl/flatset_test.cc",
-        "lib/gtl/inlined_vector_test.cc",
         "lib/gtl/int_type_test.cc",
         "lib/gtl/iterator_range_test.cc",
         "lib/gtl/manual_constructor_test.cc",
@@ -3712,6 +3733,7 @@
         ":core_cpu_internal",
         ":framework",
         ":framework_internal",
+        ":lib",
         ":test",
         ":test_main",
         ":testlib",
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000..cdaeb50
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+  graph_op_name: "BoostedTreesBucketize"
+  visibility: HIDDEN
+  in_arg {
+    name: "float_values"
+    description: <<END
+float; List of Rank 2 Tensors each containing float values for a single feature.
+END
+  }
+  in_arg {
+    name: "bucket_boundaries"
+    description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+  }
+  out_arg {
+    name: "buckets"
+    description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+inferred int; number of features.
+END
+  }
+  summary: "Bucketize each feature based on bucket boundaries."
+  description: <<END
+An op that returns a list of int tensors, where each tensor represents the
+bucketized values for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000..20da129
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+  graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+  visibility: HIDDEN
+  in_arg {
+    name: "quantile_stream_resource_handle"
+    description: <<END
+resource; Handle to quantile stream resource.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+float; The required approximation error of the stream resource.
+END
+  }
+  in_arg {
+    name: "num_streams"
+    description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+  }
+  attr {
+    name: "max_elements"
+    description : <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+  }
+  summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000..ca111af
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+  graph_op_name: "BoostedTreesMakeQuantileSummaries"
+  visibility: HIDDEN
+  in_arg {
+    name: "float_values"
+    description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+  }
+  in_arg {
+    name: "example_weights"
+    description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+  }
+  in_arg {
+    name: "epsilon"
+    description: <<END
+float; The required maximum approximation error.
+END
+  }
+  out_arg {
+    name: "summaries"
+    description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+  }
+  summary: "Makes the summary of quantiles for the batch."
+  description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000..bbeecbf
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+  graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+  visibility: HIDDEN
+  in_arg {
+    name: "quantile_stream_resource_handle"
+    description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+  }
+  in_arg {
+    name: "summaries"
+    description: <<END
+float; List of Rank 2 Tensors each containing the summaries for a single feature.
+END
+  }
+  summary: "Add the quantile summaries to each quantile stream resource."
+  description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000..2fd94ef
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+  graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+  visibility: HIDDEN
+  in_arg {
+    name: "quantile_stream_resource_handle"
+    description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+  }
+  in_arg {
+    name: "num_buckets",
+    description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+  }
+  attr {
+    name: "generate_quantiles"
+    description: <<END
+bool; If True, the output will contain num_quantiles entries per stream, where the
+ith entry is the ith quantile of the input with an approximation error of epsilon.
+Duplicate values may be present.
+If False, the output will be the boundary points of the computed histogram, which
+roughly translates to 1/epsilon boundaries, without any duplicates.
+Defaults to False.
+END
+  }
+  summary: "Flush the summaries for a quantile stream resource."
+  description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000..20667280
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+  graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+  visibility: HIDDEN
+  in_arg {
+    name: "quantile_stream_resource_handle"
+    description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+  }
+  out_arg {
+    name: "bucket_boundaries"
+    description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+  }
+  attr {
+    name: "num_features"
+    description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+  }
+  summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+  description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000..cb7786c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+  visibility: HIDDEN
+  summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
index e39213c..4408007 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
@@ -11,7 +11,8 @@
     name: "record_defaults"
     description: <<END
 One tensor per column of the input record, with either a
-scalar default value for that column or empty if the column is required.
+scalar default value for that column or an empty vector if the column is
+required.
 END
   }
   out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000..758eeb9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+  graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+  visibility: HIDDEN
+  in_arg {
+    name: "quantile_stream_resource_handle"
+    description: <<END
+resource; The reference to quantile stream resource handle.
+END
+  }
+  out_arg {
+    name: "is_initialized"
+    description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+  }
+  summary: "Checks whether a quantile stream has been initialized."
+  description: <<END
+An Op that checks if quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
new file mode 100644
index 0000000..171add1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+  graph_op_name: "ModelDataset"
+  visibility: HIDDEN
+  in_arg {
+    name: "input_dataset"
+    description: <<END
+A variant tensor representing the input dataset.
+END
+  }
+  summary: "Identity transformation that models performance."
+  description: <<END
+Identity transformation that models performance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000..27bc401
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+  graph_op_name: "ParallelInterleaveDatasetV2"
+  visibility: HIDDEN
+  attr {
+    name: "f"
+    description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+  }
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
index 8cef243..30fd97a 100644
--- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -9,7 +9,7 @@
   in_arg {
     name: "pattern"
     description: <<END
-A 1-D string tensor of the regular expression to match the input.
+A scalar string tensor containing the regular expression to match the input.
 END
   }
   out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 35f55fe..d33a36c 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
 first dimension.  Values should be sorted and can be repeated.
 END
   }
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index 70a07d9..afdc39d 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
 first dimension.  Values should be sorted and can be repeated.
 END
   }
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index b2e3eec..026b5b3 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
 first dimension.  Values should be sorted and can be repeated.
 END
   }
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 7bac02e..a168eed 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
 first dimension.  Values should be sorted and can be repeated.
 END
   }
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index a73306a..876b860 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -3,7 +3,7 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
 first dimension.  Values should be sorted and can be repeated.
 END
   }
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000..6d9d990
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+  graph_op_name: "StaticRegexFullMatch"
+  in_arg {
+    name: "input"
+    description: <<END
+A string tensor of the text to be processed.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+A bool tensor with the same shape as `input`.
+END
+  }
+  attr {
+    name: "pattern"
+    description: "The regular expression to match the input."
+  }
+  summary: "Check if the input matches the regex pattern."
+  description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 8fc1e5c..5246090 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -32,8 +32,10 @@
 If `len` defines a substring that would extend beyond the length of the input
 string, then as many characters as possible are used.
 
-If `pos` is negative or specifies a character index larger than any of the input
-strings, then an `InvalidArgumentError` is thrown.
+A negative `pos` indicates distance within the string backwards from the end.
+
+If `pos` specifies an index which is out of range for any of the input strings,
+then an `InvalidArgumentError` is thrown.
 
 `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
 Op creation.
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 907c6d2..7a60e43 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -3,15 +3,15 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
-END
+A tensor whose shape is a prefix of `data.shape`.
+END
   }
   out_arg {
     name: "output"
     description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
 END
   }
   summary: "Computes the maximum along segments of a tensor."
@@ -24,13 +24,16 @@
 [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
 Instead of computing the sum over segments, it computes the maximum such that:
 
-\\(output_i = \max_j data_j\\) where max is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+that `segment_ids[j...] == i`.
 
 If the maximum is empty for a given segment ID `i`, it outputs the smallest
 possible value for the specific numeric type,
 `output[i] = numeric_limits<T>::lowest()`.
 
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
+
 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
 <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
 </div>
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 37dd973..7e139dd 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -3,15 +3,15 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
 END
   }
   out_arg {
     name: "output"
     description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
 END
   }
   summary: "Computes the minimum along segments of a tensor."
@@ -24,11 +24,14 @@
 [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
 Instead of computing the sum over segments, it computes the minimum such that:
 
-\\(output_i = \min_j data_j\\) where min is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+that `segment_ids[j...] == i`.
 
 If the minimum is empty for a given segment ID `i`, it outputs the largest
 possible value for the specific numeric type,
 `output[i] = numeric_limits<T>::max()`.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index efbc023..9c8ea3b 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -3,15 +3,15 @@
   in_arg {
     name: "segment_ids"
     description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
 END
   }
   out_arg {
     name: "output"
     description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
 END
   }
   summary: "Computes the product along segments of a tensor."
@@ -25,9 +25,12 @@
 Instead of computing the sum over segments, it computes the product of all
 entries belonging to a segment such that:
 
-\\(output_i = \prod_j data_j\\) where the product is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+`j...` such that `segment_ids[j...] == i`.
 
 If there is no entry for a given segment ID `i`, it outputs 1.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
 END
 }
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index a887495..7e5d926 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -21,7 +21,7 @@
 for an explanation of segments.
 
 Computes a tensor such that
-\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
 that `segment_ids[j...] == i`.  Unlike `SegmentSum`, `segment_ids`
 need not be sorted and need not cover all values in the full
 range of valid values.
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3bf0532..84c6285 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -596,7 +596,7 @@
     region_offset += region.memory_size();
   }
 
-  return std::string(rendered, resolution);
+  return string(rendered, resolution);
 }
 
 void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index f8cb854..cf3d1f0 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -358,7 +358,7 @@
 
 #define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION)         \
   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
-      Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+      Tensor, DIRECTION, WrappedTensorDeviceCopy)
 
 REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
 REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index eb38820..b4d8e28 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1228,7 +1228,7 @@
       }
     };
 
-    optimizer.Optimize(lib, options_.env, device, &iter->second,
+    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                        /*shape_map=*/nullptr);
 
     // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 3f2355e..65e816c 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1255,7 +1255,7 @@
   ASSERT_TRUE(s.ok());
   ASSERT_EQ(1, outputs.size());
 
-  ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
   Tensor string_handle(DT_STRING, {});
   string_handle.flat<string>().setConstant(resource_handle.name());
 
@@ -1308,7 +1308,7 @@
   ASSERT_TRUE(s.ok());
   ASSERT_EQ(1, outputs.size());
 
-  ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+  const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
   Tensor string_handle(DT_STRING, {});
   string_handle.flat<string>().setConstant(resource_handle.name());
 
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 879a794..263467a 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -56,6 +56,7 @@
       log_device_placement_(opts.config.log_device_placement()),
       num_active_steps_(0),
       async_default_(async),
+      log_memory_(LogMemory::IsEnabled()),
       env_(opts.env),
       use_send_tensor_rpc_(false) {
   if (device_mgr_owned) {
@@ -65,13 +66,9 @@
     local_unowned_device_manager_ = device_mgr;
   }
   InitDeviceMapAndAsync();
-  if (opts.config.inter_op_parallelism_threads() > 0) {
-    runner_ = [this](std::function<void()> closure) {
-      this->thread_pool_->Schedule(closure);
-    };
-  } else {
-    runner_ = [](std::function<void()> closure) { closure(); };
-  }
+  runner_ = [this](std::function<void()> closure) {
+    this->thread_pool_->Schedule(closure);
+  };
 }
 
 void EagerContext::InitDeviceMapAndAsync() {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index eb6eb0d..5ed6057 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -33,6 +33,7 @@
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
 #endif
+#include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -141,6 +142,7 @@
   void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
 
   bool LogDevicePlacement() { return log_device_placement_; }
+  bool LogMemory() { return log_memory_; }
 
   Rendezvous* GetRendezvous() { return rendezvous_; }
 
@@ -261,6 +263,8 @@
   std::unordered_map<std::thread::id, bool> thread_local_async_
       GUARDED_BY(async_map_mu_);
 
+  const bool log_memory_;
+
   Env* const env_;
 
 #ifndef __ANDROID__
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 5b3a64b..1da1326 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -296,7 +296,7 @@
       LOG(INFO) << "Executing op " << ndef.op() << " in device "
                 << device->name();
     }
-    kernel = new KernelAndDevice(ctx->GetRendezvous());
+    kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
     auto* flr = ctx->func_lib(device);
 
     if (flr == nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3d61ff4..83d8425 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -32,21 +32,6 @@
 namespace tensorflow {
 
 // static
-Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
-                               KernelAndDevice* out) {
-  OpKernel* k = nullptr;
-  Status s = CreateOpKernel(device->device_type().c_str(), device,
-                            device->GetAllocator(AllocatorAttributes()),
-                            nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
-  out->device_ = device;
-  out->kernel_.reset(k);
-  out->flib_ = nullptr;
-  out->runner_ = nullptr;
-  out->default_runner_ = [](std::function<void()> f) { f(); };
-  return s;
-}
-
-// static
 Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
                              std::function<void(std::function<void()>)>* runner,
                              KernelAndDevice* out) {
@@ -95,6 +80,7 @@
   params.slice_reader_cache = &slice_reader_cache_;
   params.rendezvous = rendez_;
   params.cancellation_manager = &cm_;
+  params.log_memory = log_memory_;
   if (stats != nullptr) {
     params.track_allocations = true;
   }
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 0ef419c..04151a1 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -52,12 +52,12 @@
   static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
                      std::function<void(std::function<void()>)>* runner,
                      KernelAndDevice* out);
-  // TODO(ashankar): Remove this
-  static Status InitOp(Device* device, const NodeDef& ndef,
-                       KernelAndDevice* out);
 
-  KernelAndDevice(tensorflow::Rendezvous* rendez)
-      : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
+  KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+      : device_(nullptr),
+        flib_(nullptr),
+        rendez_(rendez),
+        log_memory_(log_memory) {}
 
   // TODO(ashankar): Handle list-valued inputs.
   Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -87,6 +87,7 @@
   DataTypeVector output_dtypes_;
   std::function<void(std::function<void()>)>* runner_;
   std::function<void(std::function<void()>)> default_runner_;
+  const bool log_memory_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 6abe98f..da280b2 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -104,7 +104,7 @@
                    .NumInputs(2)
                    .BuildNodeDef());
   TestEnv env;
-  KernelAndDevice k(nullptr);
+  KernelAndDevice k(nullptr, false);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
@@ -127,7 +127,7 @@
                    .NumInputs(inputs.size())
                    .BuildNodeDef());
   TestEnv env;
-  KernelAndDevice kernel(nullptr);
+  KernelAndDevice kernel(nullptr, false);
   TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
                                     nullptr, &kernel));
   tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 46bb8d9..1c9b697 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -615,11 +615,14 @@
   std::unordered_set<const Node*> nodes;
   for (auto n : g->nodes()) {
     // NOTE(mrry): "_Retval" nodes are stateful, and so will be added
-    // to the seed set of `nodes`.
+    // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
+    // specifically exclude them as seeds, to avoid unconditionally executing
+    // unused argument nodes (e.g. in a function like `lambda x, y: y`).
     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
     // still needed. It would be preferable to prune entire loops and/or
     // conditionals if they are not used in the graph.
-    if (n->IsControlFlow() || n->op_def().is_stateful()) {
+    if (n->IsControlFlow() ||
+        (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
       nodes.insert(n);
     }
   }
@@ -925,29 +928,18 @@
   }
   DCHECK(run_opts.runner != nullptr);
 
-  Executor::Args* exec_args = new Executor::Args;
+  Executor::Args exec_args;
   // Inherit the step_id from the caller.
-  exec_args->step_id = run_opts.step_id;
-  exec_args->rendezvous = run_opts.rendezvous;
-  exec_args->stats_collector = run_opts.stats_collector;
-  exec_args->cancellation_manager = run_opts.cancellation_manager;
-  exec_args->collective_executor = run_opts.collective_executor;
-  exec_args->step_container = run_opts.step_container;
-  exec_args->runner = *run_opts.runner;
-  exec_args->call_frame = frame;
+  exec_args.step_id = run_opts.step_id;
+  exec_args.rendezvous = run_opts.rendezvous;
+  exec_args.stats_collector = run_opts.stats_collector;
+  exec_args.cancellation_manager = run_opts.cancellation_manager;
+  exec_args.collective_executor = run_opts.collective_executor;
+  exec_args.step_container = run_opts.step_container;
+  exec_args.runner = *run_opts.runner;
+  exec_args.call_frame = frame;
 
-  item->exec->RunAsync(
-      // Executor args
-      *exec_args,
-      // Done callback.
-      std::bind(
-          [item, frame, exec_args](DoneCallback done,
-                                   // Start unbound arguments.
-                                   const Status& status) {
-            delete exec_args;
-            done(status);
-          },
-          std::move(done), std::placeholders::_1));
+  item->exec->RunAsync(exec_args, std::move(done));
 }
 
 bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480..7bab9be 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@
       // Name
       "SquareAndAddOneWithStatefulNodes",
       // Args
-      {"x: int32"},
+      {"x: int32", "y: float32"},
       // Return values
-      {"y: int32"},
+      {"z: int32"},
       // Attrs
       {},
       // Nodes
@@ -822,12 +822,13 @@
         "RandomUniform",
         {"shape"},
         {{"T", T}, {"dtype", DT_FLOAT}}},
-       // y = Add<T>(a, o)
-       {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+       // z = Add<T>(a, o)
+       {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
   Init({stateful_func});
 
   auto x = test::AsTensor<int32>({1, 2, 3, 4});
-  Tensor y;
+  auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+  Tensor z;
 
   FunctionLibraryRuntime::Handle handle;
   TF_CHECK_OK(
@@ -837,18 +838,19 @@
   StepStatsCollector stats_collector(&stats);
   FunctionLibraryRuntime::Options opts;
   opts.stats_collector = &stats_collector;
-  TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+  TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
   TF_CHECK_OK(flr0_->ReleaseHandle(handle));
 
   TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
-                                {x}, {&y}));
-  test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+                                {x, y}, {&z}));
+  test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
 
   stats_collector.FinalizeAndSwap(&stats);
 
-  // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+  // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+  // execute.
   std::set<string> expected_node_names(
-      {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+      {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
   std::set<string> executed_node_names;
   for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
     executed_node_names.insert(node_stats.node_name());
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 0a1797f..f9aef3a 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -56,7 +56,7 @@
     }
 
     mutex_lock l(mu_);
-    string edge_name = std::string(parsed.edge_name);
+    string edge_name(parsed.edge_name);
     if (table_.count(edge_name) > 0) {
       return errors::Internal("Send of an already sent tensor");
     }
@@ -69,7 +69,7 @@
     Tensor tensor;
     Status status = Status::OK();
     {
-      string key = std::string(parsed.edge_name);
+      string key(parsed.edge_name);
       mutex_lock l(mu_);
       if (table_.count(key) <= 0) {
         status = errors::Internal("Did not find key ", key);
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 6b76e7e..df9c3a6 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -24,9 +24,11 @@
 #include <cstdlib>
 #include "tensorflow/core/common_runtime/bfc_allocator.h"
 #include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator_registry.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
 
 #ifndef INTEL_MKL_DNN_ONLY
 #include "i_malloc.h"
@@ -48,6 +50,125 @@
   void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
 };
 
+// CPU allocator that handles small-size allocations by calling
+// suballocator directly. Mostly, it is just a wrapper around a suballocator
+// (that calls malloc and free directly) with support for bookkeeping.
+class MklSmallSizeAllocator : public VisitableAllocator {
+ public:
+  MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
+                        const string& name)
+      : sub_allocator_(sub_allocator), name_(name) {
+    stats_.bytes_limit = total_memory;
+  }
+  ~MklSmallSizeAllocator() override {}
+
+  TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator);
+
+  inline string Name() override { return name_; }
+
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    void* ptr = sub_allocator_->Alloc(alignment, num_bytes);
+    if (ptr != nullptr) {
+      std::pair<void*, size_t> map_val(ptr, num_bytes);
+      mutex_lock l(mutex_);
+      // Check that insertion in the hash map was successful.
+      CHECK(map_.insert(map_val).second);
+      // Increment statistics for small-size allocations.
+      IncrementStats(num_bytes);
+      // Call alloc visitors.
+      for (const auto& visitor : alloc_visitors_) {
+        visitor(ptr, num_bytes);
+      }
+    }
+    return ptr;
+  }
+
+  void DeallocateRaw(void* ptr) override {
+    if (ptr == nullptr) {
+      LOG(ERROR) << "tried to deallocate nullptr";
+      return;
+    }
+
+    mutex_lock l(mutex_);
+    auto map_iter = map_.find(ptr);
+    if (map_iter != map_.end()) {
+      // Call free visitors.
+      size_t dealloc_bytes = map_iter->second;
+      for (const auto& visitor : free_visitors_) {
+        visitor(ptr, dealloc_bytes);
+      }
+      sub_allocator_->Free(ptr, dealloc_bytes);
+      DecrementStats(dealloc_bytes);
+      map_.erase(map_iter);
+    } else {
+      LOG(ERROR) << "tried to deallocate invalid pointer";
+      return;
+    }
+  }
+
+  inline bool IsSmallSizeAllocation(const void* ptr) const {
+    mutex_lock l(mutex_);
+    return map_.find(ptr) != map_.end();
+  }
+
+  void GetStats(AllocatorStats* stats) override {
+    mutex_lock l(mutex_);
+    *stats = stats_;
+  }
+
+  void ClearStats() override {
+    mutex_lock l(mutex_);
+    stats_.Clear();
+  }
+
+  void AddAllocVisitor(Visitor visitor) override {
+    mutex_lock l(mutex_);
+    alloc_visitors_.push_back(visitor);
+  }
+
+  void AddFreeVisitor(Visitor visitor) override {
+    mutex_lock l(mutex_);
+    free_visitors_.push_back(visitor);
+  }
+
+ private:
+  // Increment statistics for the allocator handling small allocations.
+  inline void IncrementStats(size_t alloc_size)
+      EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+    ++stats_.num_allocs;
+    stats_.bytes_in_use += alloc_size;
+    stats_.max_bytes_in_use =
+        std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+    stats_.max_alloc_size =
+        std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+  }
+
+  // Decrement statistics for the allocator handling small allocations.
+  inline void DecrementStats(size_t dealloc_size)
+      EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+    stats_.bytes_in_use -= dealloc_size;
+  }
+
+  SubAllocator* sub_allocator_;  // Not owned by this class.
+
+  // Mutex for protecting updates to map of allocations.
+  mutable mutex mutex_;
+
+  // Allocator name
+  string name_;
+
+  // Hash map to keep track of "small" allocations
+  // We do not use BFC allocator for small allocations.
+  std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_);
+
+  // Allocator stats for small allocs
+  AllocatorStats stats_ GUARDED_BY(mutex_);
+
+  // Visitors
+  std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
+  std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
+};
+
 /// CPU allocator for MKL that wraps BFC allocator and intercepts
 /// and redirects memory allocation calls from MKL.
 class MklCPUAllocator : public VisitableAllocator {
@@ -62,7 +183,10 @@
 
   MklCPUAllocator() { TF_CHECK_OK(Initialize()); }
 
-  ~MklCPUAllocator() override { delete allocator_; }
+  ~MklCPUAllocator() override {
+    delete small_size_allocator_;
+    delete large_size_allocator_;
+  }
 
   Status Initialize() {
     VLOG(2) << "MklCPUAllocator: In MklCPUAllocator";
@@ -96,8 +220,15 @@
     }
 
     VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
-    allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
-                                  kAllowGrowth, kName);
+
+    sub_allocator_ = new MklSubAllocator();
+
+    // SubAllocator is owned by BFCAllocator, so we do not need to deallocate
+    // it in MklSmallSizeAllocator.
+    small_size_allocator_ =
+        new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName);
+    large_size_allocator_ =
+        new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName);
 #ifndef INTEL_MKL_DNN_ONLY
     // For redirecting all allocations from MKL to this allocator
     // From: http://software.intel.com/en-us/node/528565
@@ -112,23 +243,55 @@
   inline string Name() override { return kName; }
 
   inline void* AllocateRaw(size_t alignment, size_t num_bytes) override {
-    return allocator_->AllocateRaw(alignment, num_bytes);
+    // If the allocation size is less than threshold, call small allocator,
+    // otherwise call large-size allocator (BFC). We found that BFC allocator
+    // does not deliver good performance for small allocations when
+    // inter_op_parallelism_threads is high.
+    return (num_bytes < kSmallAllocationsThreshold)
+               ? small_size_allocator_->AllocateRaw(alignment, num_bytes)
+               : large_size_allocator_->AllocateRaw(alignment, num_bytes);
   }
 
   inline void DeallocateRaw(void* ptr) override {
-    allocator_->DeallocateRaw(ptr);
+    // Check if ptr is for "small" allocation. If it is, then call Free
+    // directly. Otherwise, call BFC to handle free.
+    if (small_size_allocator_->IsSmallSizeAllocation(ptr)) {
+      small_size_allocator_->DeallocateRaw(ptr);
+    } else {
+      large_size_allocator_->DeallocateRaw(ptr);
+    }
   }
 
-  void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); }
+  void GetStats(AllocatorStats* stats) override {
+    AllocatorStats l_stats, s_stats;
+    small_size_allocator_->GetStats(&s_stats);
+    large_size_allocator_->GetStats(&l_stats);
 
-  void ClearStats() override { allocator_->ClearStats(); }
+    // Combine statistics from small-size and large-size allocator.
+    stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
+    stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
+    stats->max_bytes_in_use =
+        l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+
+    // Since small-size allocations go to MklSmallSizeAllocator,
+    // max_alloc_size from large_size_allocator would be the maximum
+    // size allocated by MklCPUAllocator.
+    stats->max_alloc_size = l_stats.max_alloc_size;
+  }
+
+  void ClearStats() override {
+    small_size_allocator_->ClearStats();
+    large_size_allocator_->ClearStats();
+  }
 
   void AddAllocVisitor(Visitor visitor) override {
-    allocator_->AddAllocVisitor(visitor);
+    small_size_allocator_->AddAllocVisitor(visitor);
+    large_size_allocator_->AddAllocVisitor(visitor);
   }
 
   void AddFreeVisitor(Visitor visitor) override {
-    allocator_->AddFreeVisitor(visitor);
+    small_size_allocator_->AddFreeVisitor(visitor);
+    large_size_allocator_->AddFreeVisitor(visitor);
   }
 
  private:
@@ -148,26 +311,33 @@
     Status s = Status(error::Code::UNIMPLEMENTED,
                       "Unimplemented case for hooking MKL function.");
     TF_CHECK_OK(s);  // way to assert with an error message
-    return nullptr; // return a value and make static code analyzers happy
+    return nullptr;  // return a value and make static code analyzers happy
   }
 
   static inline void* ReallocHook(void* ptr, size_t size) {
     Status s = Status(error::Code::UNIMPLEMENTED,
                       "Unimplemented case for hooking MKL function.");
     TF_CHECK_OK(s);  // way to assert with an error message
-    return nullptr; // return a value and make static code analyzers happy
+    return nullptr;  // return a value and make static code analyzers happy
   }
 
-  /// Do we allow growth in BFC Allocator
+  // Do we allow growth in BFC Allocator
   static const bool kAllowGrowth = true;
 
-  /// Name
+  // Name
   static constexpr const char* kName = "mklcpu";
 
-  /// The alignment that we need for the allocations
+  // The alignment that we need for the allocations
   static constexpr const size_t kAlignment = 64;
 
-  VisitableAllocator* allocator_;  // owned by this class
+  VisitableAllocator* large_size_allocator_;     // owned by this class
+  MklSmallSizeAllocator* small_size_allocator_;  // owned by this class.
+
+  SubAllocator* sub_allocator_;  // not owned by this class
+
+  // Size in bytes that defines the upper-bound for "small" allocations.
+  // Any allocation below this threshold is "small" allocation.
+  static constexpr const size_t kSmallAllocationsThreshold = 4096;
 
   // Prevent copying and assignment
   TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 7f3c25d..3b59995 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -254,9 +254,11 @@
                                               old_root_member.device_name,
                                               allow_soft_placement_);
     if (!s.ok()) {
-      return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
-                                     "' and '", y.name(), ": ",
-                                     s.error_message());
+      return errors::InvalidArgument(
+          "Cannot colocate nodes ",
+          errors::FormatColocationNodeForError(x.name()), " and ",
+          errors::FormatColocationNodeForError(y.name()), ": ",
+          s.error_message());
     }
 
     // Ensure that the common root has at least one supported device
@@ -267,8 +269,10 @@
                           old_root_member.supported_device_types);
     if (new_root_member.supported_device_types.empty()) {
       return errors::InvalidArgument(
-          "Cannot colocate nodes '", x.name(), "' and '", y.name(),
-          "' because no device type supports both of those nodes and the "
+          "Cannot colocate nodes ",
+          errors::FormatColocationNodeForError(x.name()), " and ",
+          errors::FormatColocationNodeForError(y.name()),
+          " because no device type supports both of those nodes and the "
           "other nodes colocated with them.",
           DebugInfo(x_root), DebugInfo(y_root));
     }
@@ -376,8 +380,9 @@
           // merged set device is different, so print both.
           return errors::InvalidArgument(
               "Could not satisfy explicit device specification '",
-              node->requested_device(),
-              "' because the node was colocated with a group of nodes that "
+              node->requested_device(), "' because the node ",
+              errors::FormatColocationNodeForError(node->name()),
+              " was colocated with a group of nodes that ",
               "required incompatible device '",
               DeviceNameUtils::ParsedNameToString(
                   members_[node_root].device_name),
@@ -809,10 +814,10 @@
     std::vector<Device*>* devices;
     Status status = colocation_graph.GetDevicesForNode(node, &devices);
     if (!status.ok()) {
-      return AttachDef(errors::InvalidArgument(
-                           "Cannot assign a device for operation ",
-                           RichNodeName(node), ": ", status.error_message()),
-                       *node);
+      return AttachDef(
+          errors::InvalidArgument("Cannot assign a device for operation ",
+                                  node->name(), ": ", status.error_message()),
+          *node);
     }
 
     // Returns the first device in sorted devices list so we will always
@@ -856,10 +861,10 @@
     std::vector<Device*>* devices;
     Status status = colocation_graph.GetDevicesForNode(node, &devices);
     if (!status.ok()) {
-      return AttachDef(errors::InvalidArgument(
-                           "Cannot assign a device for operation ",
-                           RichNodeName(node), ": ", status.error_message()),
-                       *node);
+      return AttachDef(
+          errors::InvalidArgument("Cannot assign a device for operation ",
+                                  node->name(), ": ", status.error_message()),
+          *node);
     }
 
     int assigned_device = -1;
@@ -925,21 +930,4 @@
   }
 }
 
-bool Placer::ClientHandlesErrorFormatting() const {
-  return options_ != nullptr &&
-         options_->config.experimental().client_handles_error_formatting();
-}
-
-// Returns the node name in single quotes. If the client handles formatted
-// errors, appends a formatting tag which the client will reformat into, for
-// example, " (defined at filename:123)".
-// TODO(shikharagarwal): Remove this function once
-// client_handles_error_formatting flag is removed.
-string Placer::RichNodeName(const Node* node) const {
-  if (ClientHandlesErrorFormatting()) {
-    return errors::FormatNodeNameForError(node->name());
-  }
-  return strings::StrCat("'", node->name(), "'");
-}
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index cefcdd2..f97ffe7 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -87,8 +87,6 @@
   // placement if the SessionOptions entry in 'options_' requests it.
   void AssignAndLog(int assigned_device, Node* node) const;
   void LogDeviceAssignment(const Node* node) const;
-  bool ClientHandlesErrorFormatting() const;
-  string RichNodeName(const Node* node) const;
 
   Graph* const graph_;              // Not owned.
   const DeviceSet* const devices_;  // Not owned.
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 83d27e2..9b8a95e 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -800,11 +800,11 @@
   }
 
   Status s = Place(&g);
-  EXPECT_TRUE(
-      str_util::StrContains(s.error_message(),
-                            "Cannot colocate nodes 'foo' and 'in' because no "
-                            "device type supports both of those nodes and the "
-                            "other nodes colocated with them"));
+  EXPECT_TRUE(str_util::StrContains(
+      s.error_message(),
+      "Cannot colocate nodes {{colocation_node foo}} and "
+      "{{colocation_node in}} because no device type supports both of those "
+      "nodes and the other nodes colocated with them"));
 }
 
 TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
@@ -867,9 +867,9 @@
   Status s = Place(&g);
   EXPECT_TRUE(str_util::StrContains(
       s.error_message(),
-      "Cannot colocate nodes 'var3' and 'assign3' because no "
-      "device type supports both of those nodes and the other "
-      "nodes colocated with them."));
+      "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node "
+      "assign3}} because no device type supports both of those nodes and the "
+      "other nodes colocated with them."));
 }
 
 TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -1154,35 +1154,12 @@
   }
 
   SessionOptions options;
-  options.config.mutable_experimental()->set_client_handles_error_formatting(
-      true);
   Status s = Place(&g, &options);
   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
   LOG(WARNING) << s.error_message();
-  EXPECT_TRUE(str_util::StrContains(
-      s.error_message(), "Cannot assign a device for operation {{node in}}"));
-}
-
-// Test that the "Cannot assign a device" error message does not contain a
-// format tag when not it shouldn't
-TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
-  Graph g(OpRegistry::Global());
-  {  // Scope for temporary variables used to construct g.
-    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
-    ops::SourceOp("TestDevice",
-                  b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
-    TF_EXPECT_OK(BuildGraph(b, &g));
-  }
-
-  SessionOptions options;
-  options.config.mutable_experimental()->set_client_handles_error_formatting(
-      false);
-  Status s = Place(&g, &options);
-  EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
-  EXPECT_TRUE(str_util::StrContains(
-      s.error_message(), "Cannot assign a device for operation 'in'"));
-  EXPECT_FALSE(str_util::StrContains(
-      s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+  EXPECT_TRUE(str_util::StrContains(s.error_message(),
+                                    "Cannot assign a device for operation in"));
+  EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}"));
 }
 
 // Test that placement fails when a node requests an explicit device that is not
@@ -1288,8 +1265,9 @@
 
   Status s = Place(&g);
   EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
-  EXPECT_TRUE(str_util::StrContains(
-      s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
+  EXPECT_TRUE(str_util::StrContains(s.error_message(),
+                                    "Cannot colocate nodes {{colocation_node "
+                                    "var}} and {{colocation_node assign}}"));
 }
 
 // Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 10a24ed..fdad8de 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -26,6 +26,7 @@
 
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
 
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 1e3fed0..43ca3f1 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/platform/mutex.h"
 
 #include "tensorflow/core/util/reffed_status_callback.h"
 
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 65ff356..5b19157 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -70,7 +70,7 @@
     // Save only the tensors in output_names in the session.
     for (const string& name : output_names) {
       TensorId id(ParseTensorName(name));
-      const string& op_name = std::string(id.first);
+      const string op_name(id.first);
       auto it = tensors_.find(op_name);
       if (it != tensors_.end()) {
         // Save the tensor to the session state.
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
index 04d5af9..22650b0 100644
--- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 9c2510e..836cb8e 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -176,7 +176,7 @@
   } else {
     // Convert the captured string into an integer. But first we need to put
     // the digits back in order
-    string ordered_capture = std::string(capture);
+    string ordered_capture(capture);
     std::reverse(ordered_capture.begin(), ordered_capture.end());
     int gpu_id;
     CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -205,7 +205,7 @@
   } else {
     // Convert the captured string into an integer. But first we need to put
     // the digits back in order
-    string ordered_capture = std::string(capture);
+    string ordered_capture(capture);
     std::reverse(ordered_capture.begin(), ordered_capture.end());
     int gpu_id;
     CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -252,7 +252,7 @@
 
   for (auto& itr : per_device_stats) {
     const StringPiece device_name = itr.first;
-    const int gpu_id = ExtractGpuWithoutStream(std::string(device_name));
+    const int gpu_id = ExtractGpuWithoutStream(string(device_name));
     if (gpu_id >= 0) {
       // Reference the gpu hardware stats in addition to the regular stats
       // for this gpu device if they're available.
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
index 39215ef..e1b1630 100644
--- a/tensorflow/core/common_runtime/tracing_device.h
+++ b/tensorflow/core/common_runtime/tracing_device.h
@@ -35,8 +35,11 @@
       : Device(env, attributes) {}
 
   void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+    const tracing::TraceCollector* trace_collector =
+        tracing::GetTraceCollector();
     if (TF_PREDICT_FALSE(
-            tracing::GetTraceCollector() ||
+            (trace_collector &&
+             trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
             tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
       const string& op_name = op_kernel->name();
       tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 38863db..6994dec 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -693,6 +693,7 @@
 mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
 
 bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+  mutex_lock l(bytes_mu);
   if (globalDiskBytesLimit == 0) {
     const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
     if (env_tfdbg_disk_bytes_limit == nullptr ||
@@ -707,7 +708,6 @@
   if (bytes == 0) {
     return true;
   }
-  mutex_lock l(bytes_mu);
   if (diskBytesUsed + bytes < globalDiskBytesLimit) {
     diskBytesUsed += bytes;
     return true;
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index e7142a4..e36e51d 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -199,7 +199,13 @@
 //     to determine if all features within the FeatureList must
 //     have the same size.  The same holds for this FeatureList across multiple
 //     examples.
-//
+//   - For sequence modeling, e.g.:
+//        http://colah.github.io/posts/2015-08-Understanding-LSTMs/
+//        https://github.com/tensorflow/nmt
+//     the feature lists represent a sequence of frames.
+//     In this scenario, all FeatureLists in a SequenceExample have the same
+//     number of Feature messages, so that the ith element in each FeatureList
+//     is part of the ith frame (or time step).
 // Examples of conformant and non-conformant examples' FeatureLists:
 //
 // Conformant FeatureLists:
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c..2a7ee16 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/core/framework/allocator_registry.h"
 #include "tensorflow/core/framework/log_memory.h"
 #include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@
   for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
 }
 
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+  for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+  for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
 // If true, cpu allocator collects more stats.
 static bool cpu_allocator_collect_stats = false;
 // If true, cpu allocator collects full stats.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe..ded120b 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,13 @@
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
+class Variant;
+
 // Attributes for a single allocation call. Different calls to the same
 // allocator could potentially have different allocation attributes.
 struct AllocationAttributes {
@@ -228,13 +229,9 @@
     for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
   }
 
-  virtual void RunVariantCtor(Variant* p, size_t n) {
-    for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
-  }
+  virtual void RunVariantCtor(Variant* p, size_t n);
 
-  virtual void RunVariantDtor(Variant* p, size_t n) {
-    for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
-  }
+  virtual void RunVariantDtor(Variant* p, size_t n);
 
   // TODO(jeff): Maybe provide some interface to give info about
   // current allocation state (total number of bytes available for
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282c..e907c52 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/numa.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a39947..4ffd732 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@
 #include <numeric>
 #include <vector>
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 9ffd8e1..5281c56 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/graph/node_builder.h"
 
 namespace tensorflow {
+namespace data {
 
 namespace {
 
@@ -329,4 +330,5 @@
   }
 }
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 04865a1..4ee6749 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@
 #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -40,6 +41,13 @@
 
 namespace tensorflow {
 
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class Node;
+
+namespace data {
+
 class DatasetBase;
 class SerializationContext;
 
@@ -66,11 +74,6 @@
   virtual ~IteratorStateWriter() {}
 };
 
-// Forward declarations to avoid introducing a dependency on headers in
-// "tensorflow/core/graph/...".
-class GraphDefBuilder;
-class Node;
-
 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
 class GraphDefBuilderWrapper {
  public:
@@ -222,8 +225,7 @@
     return (str_util::EndsWith(op_def->name(), "Dataset") &&
             op_def->output_arg_size() == 1 &&
             op_def->output_arg(0).type() == DT_VARIANT) ||
-           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
-               op_def->name());
+           WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
   }
 
   bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -290,6 +292,9 @@
 
     // The Allocator to be used to allocate the output of an iterator.
     std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+    // If non-null, identifies the object used for performance modeling.
+    std::shared_ptr<model::Model> model = nullptr;
   };
 
   explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -341,6 +346,10 @@
     return params_.stats_aggregator_getter;
   }
 
+  std::shared_ptr<model::Model> model() { return params_.model; }
+
+  Params params() { return params_; }
+
  private:
   Params params_;
 };
@@ -375,7 +384,11 @@
 // defined below.
 class IteratorBase {
  public:
-  virtual ~IteratorBase() {}
+  virtual ~IteratorBase() {
+    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+      (*rit)();
+    }
+  }
 
   // Gets the next output from the range that this iterator is traversing.
   //
@@ -409,6 +422,10 @@
   // in the outputs of this iterator.
   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
 
+  // Returns a string that identifies the sequence of iterators leading up to
+  // this iterator.
+  virtual const string& prefix() const = 0;
+
   // Performs initialization that needs to happen outside of a constructor to
   // properly propagate errors.
   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -448,6 +465,18 @@
                                  IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }
+
+ private:
+  friend class DatasetBase;  // for access to `AddCleanupFunction`
+
+  // Registers a cleanup function to be called upon object destruction.
+  //
+  // Registered functions are invoked in the reverse order of registration.
+  void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+    cleanup_fns_.push_back(std::move(cleanup_fn));
+  }
+
+  std::vector<std::function<void()>> cleanup_fns_;
 };
 
 // Represents runtime information needed to construct a dataset.
@@ -497,6 +526,27 @@
   Status MakeIterator(IteratorContext* ctx, const string& prefix,
                       std::unique_ptr<IteratorBase>* iterator) const {
     *iterator = MakeIteratorInternal(prefix);
+    if (ctx->model()) {
+      // The prefix might contain an index. We need to strip it to make it
+      // possible for the model to successfully identify the output node.
+      string sanitized_prefix = prefix;
+      if (str_util::EndsWith(prefix, "]")) {
+        sanitized_prefix = prefix.substr(0, prefix.rfind('['));
+      }
+      std::shared_ptr<model::Node> node =
+          ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix);
+      std::vector<string> tokens =
+          str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty());
+      node->set_name(tokens[tokens.size() - 1]);
+      std::shared_ptr<model::Model> model = ctx->model();
+      const string& prefix = (*iterator)->prefix();
+      (*iterator)->AddCleanupFunction([model, node, prefix]() {
+        if (node->output()) {
+          node->output()->remove_input(node);
+        }
+        model->RemoveNode(prefix);
+      });
+    }
     return (*iterator)->Initialize(ctx);
   }
 
@@ -523,6 +573,8 @@
                       IteratorStateWriter* writer) const;
 
  protected:
+  friend class DatasetToGraphOp;  // For access to graph related members.
+
   class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
    public:
     DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -540,8 +592,6 @@
   virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const = 0;
 
-  friend class DatasetToGraphOp;  // For access to graph related members.
-
  private:
   const string name_;
 };
@@ -564,7 +614,7 @@
   ~DatasetBaseIterator() override { params_.dataset->Unref(); }
 
   // The sequence of iterators leading up to this iterator.
-  const string& prefix() const { return params_.prefix; }
+  const string& prefix() const override { return params_.prefix; }
 
   const DataTypeVector& output_dtypes() const override {
     return params_.dataset->output_dtypes();
@@ -577,7 +627,23 @@
   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) final {
     tracing::ScopedActivity activity(params_.prefix);
-    Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+    Status s;
+    if (ctx->model()) {
+      std::shared_ptr<model::Node> node =
+          ctx->model()->LookupNode(params_.prefix);
+      if (node->output()) {
+        node->output()->stop_work();
+      }
+      node->start_work();
+      s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+      node->stop_work();
+      node->add_element();
+      if (node->output()) {
+        node->output()->start_work();
+      }
+    } else {
+      s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+    }
     if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
       s = errors::Internal(
           "Iterator \"", params_.prefix,
@@ -604,6 +670,39 @@
     return strings::StrCat(params_.prefix, ":", name);
   }
 
+  // When performance modeling is enabled, this method sets metadata entry for
+  // the model node corresponding to this iterator.
+  void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
+    if (ctx->model()) {
+      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+      if (node) {
+        node->set_metadata(key, value);
+      }
+    }
+  }
+
+  // When performance modeling is enabled, this method records the fact that
+  // a thread of this iterator has started work.
+  void StartWork(IteratorContext* ctx) {
+    if (ctx->model()) {
+      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+      if (node) {
+        node->start_work();
+      }
+    }
+  }
+
+  // When performance modeling is enabled, this method records the fact that
+  // a thread of this iterator has stopped work.
+  void StopWork(IteratorContext* ctx) {
+    if (ctx->model()) {
+      std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
+      if (node) {
+        node->stop_work();
+      }
+    }
+  }
+
  private:
   BaseParams params_;
 };
@@ -751,6 +850,21 @@
   std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
 };
 
+}  // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::DatasetBase;
+using data::DatasetContext;
+using data::DatasetIterator;
+using data::DatasetOpKernel;
+using data::IteratorBase;
+using data::IteratorContext;
+using data::IteratorStateReader;
+using data::IteratorStateWriter;
+using data::SerializationContext;
+using data::UnaryDatasetOpKernel;
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 3b48999..74bd39c 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -16,38 +16,38 @@
 #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
 #define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
 
+#include <unordered_set>
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
-namespace dataset {
+namespace data {
 // Registry for stateful ops that need to be used in dataset functions.
 // See below macro for usage details.
 class WhitelistedStatefulOpRegistry {
  public:
-  Status Add(StringPiece op_name) {
-    op_names_.insert(op_name);
+  Status Add(string op_name) {
+    op_names_.insert(std::move(op_name));
     return Status::OK();
   }
 
-  bool Contains(StringPiece op_name) {
-    return op_names_.find(op_name) != op_names_.end();
-  }
+  bool Contains(const string& op_name) { return op_names_.count(op_name); }
 
   static WhitelistedStatefulOpRegistry* Global() {
-    static WhitelistedStatefulOpRegistry* reg =
-        new WhitelistedStatefulOpRegistry;
+    static auto* reg = new WhitelistedStatefulOpRegistry;
     return reg;
   }
 
  private:
-  WhitelistedStatefulOpRegistry() {}
-  WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+  WhitelistedStatefulOpRegistry() = default;
+  WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
+      delete;
   WhitelistedStatefulOpRegistry operator=(
-      WhitelistedStatefulOpRegistry const& copy);
-  std::set<StringPiece> op_names_;
+      WhitelistedStatefulOpRegistry const& copy) = delete;
+
+  std::unordered_set<string> op_names_;
 };
 
-}  // namespace dataset
+}  // namespace data
 
 // Use this macro to whitelist an op that is marked stateful but needs to be
 // used inside a map_fn in an input pipeline. This is only needed if you wish
@@ -67,10 +67,9 @@
   WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
 #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
   WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
-#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)        \
-  static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED =      \
-      ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
-          name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)   \
+  static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
+      ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 26f3267..d979353d 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,6 +1154,17 @@
   return default_registry_->LookUp(op, op_reg_data);
 }
 
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+  tf_shared_lock l(mu_);
+  int index = 0;
+  string name = strings::StrCat(prefix, index);
+  while (function_defs_.find(name) != function_defs_.end()) {
+    ++index;
+    name = strings::StrCat(prefix, index);
+  }
+  return name;
+}
+
 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
     const NodeDef& ndef) const {
   if (ndef.op() != kGradientOp) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 03296a7..e01eb75 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,6 +358,10 @@
                 const OpRegistrationData** op_reg_data) const override
       LOCKS_EXCLUDED(mu_);
 
+  // Generates new function name with the specified prefix that is unique
+  // across this library.
+  string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
   // Ops created for function arguments bear the name given by `kArgOp`; those
   // created for return values bear the name given by `kRetOp`.
   static constexpr const char* const kArgOp = "_Arg";
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 46b169d..c5a4f66 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -110,6 +110,22 @@
       });
 }
 
+FunctionDef XAddX() {
+  return FDH::Define(
+      // Name
+      "XAddX",
+      // Args
+      {"x: T"},
+      // Return values
+      {"y: T"},
+      // Attr def
+      {"T: {float, double, int32, int64}"},
+      // Nodes
+      {
+          {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
+      });
+}
+
 FunctionDef XTimesTwoInt32() {
   const Tensor kTwo = test::AsScalar<int64>(2);
   return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index 6d6476b..ad61a76 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -63,6 +63,9 @@
 // x:T -> x * 2.
 FunctionDef XTimesTwo();
 
+// x:T -> x + x.
+FunctionDef XAddX();
+
 // x:T -> x * 2, where x is int32.
 FunctionDef XTimesTwoInt32();
 
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000..250b006
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
+void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
+  mutex_lock l(mu_);
+  switch (type_) {
+    case Type::PARALLEL_INTERLEAVE_V2: {
+      for (auto input : inputs_) {
+        input->CollectKnobs(knobs);
+      }
+      int64 processing_time = static_cast<int64>(
+          static_cast<double>(ProcessingTimeLocked() -
+                              inputs_.front()->ProcessingTime()) /
+          static_cast<double>(inputs_.size() - 1));
+      knobs->emplace_back(
+          Node::Knob{this, processing_time, metadata_["parallelism"]});
+      return;
+    }
+    case Type::MAP_AND_BATCH:
+    case Type::PARALLEL_MAP: {
+      for (auto input : inputs_) {
+        input->CollectKnobs(knobs);
+      }
+      knobs->emplace_back(
+          Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
+      return;
+    }
+    case Type::BATCH:
+    case Type::CACHE:
+    case Type::CONCATENATE:
+    case Type::FILTER:
+    case Type::FLAT_MAP:
+    case Type::INTERLEAVE:
+    case Type::MAP:
+    case Type::PADDED_BATCH:
+    case Type::PARALLEL_INTERLEAVE:
+    case Type::PREFETCH:
+    case Type::REPEAT:
+    case Type::SHUFFLE:
+    case Type::SKIP:
+    case Type::TAKE:
+    case Type::ZIP: {
+      for (auto input : inputs_) {
+        input->CollectKnobs(knobs);
+      }
+      return;
+    }
+    default:
+      return;
+  }
+}
+
+int64 Node::ProcessingTimeLocked() {
+  switch (type_) {
+    case Type::BATCH:
+    case Type::MAP_AND_BATCH:
+    case Type::PADDED_BATCH: {
+      int64 batch_size = metadata_["batch_size"];
+      return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+    }
+    case Type::FILTER: {
+      std::shared_ptr<Node> input = inputs_.front();
+      double ratio = static_cast<double>(input->num_elements()) /
+                     static_cast<double>(num_elements_);
+      return NanosPerElementLocked() +
+             static_cast<int64>(ratio *
+                                static_cast<double>(ProcessingTimeForInputs()));
+    }
+    case Type::FLAT_MAP:
+    case Type::INTERLEAVE:
+    case Type::PARALLEL_INTERLEAVE:
+    case Type::PARALLEL_INTERLEAVE_V2: {
+      // TODO(jsimsa): model the first input
+      // TODO(jsimsa): use processing time history as a prior for future inputs
+      if (inputs_.size() <= 1) {
+        return NanosPerElementLocked();
+      }
+      int64 processing_time =
+          ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+      return NanosPerElementLocked() +
+             static_cast<double>(processing_time) /
+                 static_cast<double>(inputs_.size() - 1);
+    }
+    case Type::CACHE:
+    case Type::CONCATENATE:
+    case Type::MAP:
+    case Type::PARALLEL_MAP:
+    case Type::PREFETCH:
+      // TODO(jsimsa): use processing time history as a prior for future inputs
+    case Type::REPEAT:
+    case Type::SHUFFLE:
+    case Type::SKIP:
+    case Type::TAKE:
+    case Type::ZIP: {
+      return NanosPerElementLocked() + ProcessingTimeForInputs();
+    }
+    default:
+      return NanosPerElementLocked();
+  }
+}
+
+int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
+  switch (type_) {
+    case Type::BATCH:
+    case Type::PADDED_BATCH: {
+      double batch_size = metadata_["batch_size"];
+      int64 old_value = (*input_times)[input_times->size() - 1];
+      (*input_times)[input_times->size() - 1] = static_cast<int64>(
+          static_cast<double>(old_value + NanosPerElementLocked()) /
+          batch_size);
+      auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+        (*input_times)[input_times->size() - 1] = old_value;
+      });
+      return NanosPerElementLocked() +
+             batch_size * OutputTimeForInputs(input_times);
+    }
+    case Type::FILTER: {
+      std::shared_ptr<Node> input = inputs_.front();
+      int64 old_value = (*input_times)[input_times->size() - 1];
+      double ratio = static_cast<double>(input->num_elements()) /
+                     static_cast<double>(num_elements_);
+      (*input_times)[input_times->size() - 1] = static_cast<int64>(
+          static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+      auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+        (*input_times)[input_times->size() - 1] = old_value;
+      });
+      return NanosPerElementLocked() +
+             static_cast<int64>(
+                 static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+    }
+    case Type::FLAT_MAP:
+    case Type::INTERLEAVE: {
+      // TODO(jsimsa): model the first input
+      // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+      if (inputs_.size() <= 1) {
+        return NanosPerElementLocked();
+      }
+      int64 delta =
+          static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+                             static_cast<double>(inputs_.size() - 1));
+      (*input_times)[input_times->size() - 1] += delta;
+      auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+        (*input_times)[input_times->size() - 1] -= delta;
+      });
+      int64 output_time = OutputTimeForInputs(input_times) -
+                          inputs_.front()->OutputTime(input_times);
+      return NanosPerElementLocked() +
+             static_cast<double>(output_time) /
+                 static_cast<double>(inputs_.size() - 1);
+    }
+    case Type::MAP_AND_BATCH: {
+      double batch_size = metadata_["batch_size"];
+      double parallelism = metadata_["parallelism"];
+      int64 delta =
+          static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+                             (batch_size * parallelism));
+      input_times->push_back(delta);
+      auto cleanup =
+          gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+      int64 output_time = static_cast<int64>(
+          static_cast<double>(NanosPerElementLocked()) / parallelism +
+          batch_size * OutputTimeForInputs(input_times));
+      return std::max(0LL,
+                      output_time - input_times->at(input_times->size() - 2));
+    }
+    case Type::PARALLEL_INTERLEAVE:
+    case Type::PARALLEL_INTERLEAVE_V2: {
+      // TODO(jsimsa): model the first input
+      if (inputs_.size() <= 1) {
+        return NanosPerElementLocked();
+      }
+      int64 delta =
+          static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+                             static_cast<double>(inputs_.size() - 1));
+      input_times->push_back(delta);
+      auto cleanup =
+          gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+      int64 inputs_output_time = OutputTimeForInputs(input_times) -
+                                 inputs_.front()->OutputTime(input_times);
+      double parallelism = std::min(port::NumSchedulableCPUs(),
+                                    static_cast<int>(metadata_["parallelism"]));
+      int64 output_time =
+          NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+                                      static_cast<double>(inputs_.size() - 1)) /
+                                     parallelism);
+      return std::max(0LL,
+                      output_time - input_times->at(input_times->size() - 2));
+    }
+    case Type::PARALLEL_MAP: {
+      double parallelism = std::min(port::NumSchedulableCPUs(),
+                                    static_cast<int>(metadata_["parallelism"]));
+      int64 delta = static_cast<int64>(
+          static_cast<double>(NanosPerElementLocked()) / parallelism);
+      input_times->push_back(delta);
+      auto cleanup =
+          gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+      int64 output_time =
+          static_cast<double>(NanosPerElementLocked()) / parallelism +
+          OutputTimeForInputs(input_times);
+      return std::max(0LL,
+                      output_time - input_times->at(input_times->size() - 2));
+    }
+    case Type::PREFETCH: {
+      int64 delta = NanosPerElementLocked();
+      input_times->push_back(delta);
+      auto cleanup =
+          gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+      return std::max(0LL, NanosPerElementLocked() +
+                               OutputTimeForInputs(input_times) -
+                               input_times->at(input_times->size() - 2));
+    }
+    case Type::CACHE:
+    case Type::CONCATENATE:
+    case Type::MAP:
+    case Type::REPEAT:
+    case Type::SHUFFLE:
+    case Type::SKIP:
+    case Type::TAKE:
+    case Type::ZIP: {
+      int64 delta = NanosPerElementLocked();
+      (*input_times)[input_times->size() - 1] += delta;
+      auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+        (*input_times)[input_times->size() - 1] -= delta;
+      });
+      return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+    }
+    default:
+      return NanosPerElementLocked();
+  }
+}
+
+Model::Model(const proto::Model& model_proto) {
+  id_counter_ = model_proto.id_counter();
+  std::map<int64, std::shared_ptr<Node>> lookup_table;
+  for (auto node_proto : model_proto.node()) {
+    std::shared_ptr<Node> node(new Node(node_proto));
+    lookup_table[node_proto.id()] = node;
+  }
+  for (auto node_proto : model_proto.node()) {
+    std::shared_ptr<Node> node = lookup_table[node_proto.id()];
+    for (int64 id : node_proto.input()) {
+      node->add_input(lookup_table[id]);
+    }
+    node->set_output(lookup_table[node_proto.output()]);
+  }
+  output_ = lookup_table[model_proto.output()];
+}
+
+std::shared_ptr<Node> Model::AddNode(const string& name,
+                                     const string& output_name) {
+  mutex_lock l(mu_);
+  std::shared_ptr<Node> output;
+  auto it = lookup_table_.find(output_name);
+  if (it != lookup_table_.end()) {
+    output = it->second;
+  }
+  std::shared_ptr<Node> node(new Node(id_counter_++, output));
+  if (!output_) {
+    output_ = node;
+  }
+  if (output) {
+    output->add_input(node);
+  }
+  lookup_table_.insert(std::make_pair(name, node));
+  return node;
+}
+
+std::shared_ptr<Node> Model::LookupNode(const string& name) {
+  tf_shared_lock l(mu_);
+  std::shared_ptr<Node> result;
+  auto it = lookup_table_.find(name);
+  if (it != lookup_table_.end()) {
+    result = it->second;
+  }
+  return result;
+}
+
+void Model::Optimize() {
+  mutex_lock l(mu_);
+  int64 processing_time = ProcessingTime();
+  int64 num_cpus = port::NumSchedulableCPUs();
+  std::vector<Node::Knob> knobs = CollectKnobs();
+  // The optimization algorithm starts by setting all parallelism knobs to 1. It
+  // then repeatedly identifies the knob that, when turned up by 1, decreases
+  // the output time the most. This process is repeated until all knobs reach
+  // the number of schedulable CPUs or the projected output time is less than or
+  // equal to the processing time needed to produce an element divided by the
+  // number of schedulable CPUs.
+  for (auto& knob : knobs) {
+    LOG(INFO) << knob.node->name() << " " << knob.processing_time;
+    knob.value = 1;
+    knob.node->set_metadata("parallelism", knob.value);
+  }
+  while (true) {
+    int64 output_time = OutputTime();
+    bool all_knobs = true;
+    for (auto knob : knobs) {
+      if (knob.value < num_cpus) {
+        all_knobs = false;
+        break;
+      }
+    }
+    if (output_time < processing_time / num_cpus || all_knobs) {
+      break;
+    }
+    int64 best_delta = -1;
+    int best_knob = -1;
+    for (int i = 0; i < knobs.size(); ++i) {
+      if (knobs[i].value == num_cpus) {
+        continue;
+      }
+      knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
+      int64 delta = output_time - OutputTime();
+      if (delta > best_delta) {
+        best_delta = delta;
+        best_knob = i;
+      }
+      knobs[i].node->set_metadata("parallelism", knobs[i].value);
+    }
+    knobs[best_knob].value++;
+    knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
+  }
+  for (auto knob : knobs) {
+    LOG(INFO) << knob.node->name() << " " << knob.value;
+  }
+  LOG(INFO) << "output time: " << OutputTime();
+  LOG(INFO) << "processing time: " << ProcessingTime();
+}
+
+void Model::OutputToFile() {
+  proto::Model model_proto;
+  ToProto(&model_proto);
+  string filename;
+  Env::Default()->LocalTempFilename(&filename);
+  TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
+                                model_proto.SerializeAsString()));
+  LOG(INFO) << filename;
+}
+
+void Model::RemoveNode(const string& prefix) {
+  mutex_lock l(mu_);
+  lookup_table_.erase(prefix);
+}
+
+void Model::ToProto(proto::Model* model_proto) {
+  mutex_lock l(mu_);
+  model_proto->set_id_counter(id_counter_);
+  model_proto->set_output(output_->id());
+  AddNodeToProto(output_, model_proto);
+}
+
+// static
+void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
+                           proto::Model* model_proto) {
+  proto::Node* node_proto = model_proto->add_node();
+  node->ToProto(node_proto);
+  for (const std::shared_ptr<Node>& input : node->inputs()) {
+    AddNodeToProto(input, model_proto);
+  }
+}
+
+std::vector<Node::Knob> Model::CollectKnobs() {
+  std::vector<Node::Knob> knobs;
+  output_->CollectKnobs(&knobs);
+  return knobs;
+}
+
+int64 Model::OutputTime() {
+  std::vector<int64> input_times(1, 0);
+  return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+}  // namespace model
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000..9817290
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,396 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread>  // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/model.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+class Model;
+class Node;
+
+// Abstract representation of a TensorFlow input pipeline node. It collects
+// information about inputs to this node, processing time spent executing the
+// node logic, number of elements produced by the node, various other
+// information (e.g. batch size or execution parallelism).
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boiler plate code for creating the abstract representation of
+// the input pipeline and collecting common information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// In addition, `DatasetBaseIterator` provides wrappers that can be used for
+// transformation-specific information collection. The `SetMetadata` wrapper can
+// be used to pass arbitrary metadata to the modeling framework, while the
+// `StartWork` and `StopWork` wrappers should be used to correctly account for
+// processing time of multi-threaded transformations that yield the CPU; such
+// transformations should invoke `StartWork()` when a transformation thread
+// starts executing (e.g. when created or woken up) and `StopWork()` when a
+// transformation thread stops executing (e.g. when returning or waiting).
+//
+// TODO(jsimsa): Create an API to capture the abstract semantics of each
+// tf.data transformation and replace switch-case blocks with inheritance.
+class Node {
+ public:
+  Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
+
+  explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
+    name_ = node_proto.name();
+    type_ = TypeFromName(node_proto.name());
+    processing_time_ = node_proto.processing_time();
+    num_elements_ = node_proto.num_elements();
+    metadata_.insert(node_proto.metadata().begin(),
+                     node_proto.metadata().end());
+  }
+
+  // Records that the node produced an element.
+  void add_element() LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    num_elements_++;
+  }
+
+  // Adds an input.
+  void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    inputs_.push_back(node);
+  }
+
+  // Increments the aggregate processing time by the given delta.
+  void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    processing_time_ += delta;
+  }
+
+  // Returns the unique node ID.
+  int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+  // Returns the node inputs.
+  std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    return inputs_;
+  }
+
+  // Returns the node name.
+  const string& name() LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    return name_;
+  }
+
+  // Returns the number of elements produced by the node.
+  int64 num_elements() LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    return num_elements_;
+  }
+
+  // Returns the node output.
+  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    return output_;
+  }
+
+  // Removes an input.
+  void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    inputs_.remove(input);
+  }
+
+  // Adds the given key-value pair to the node metadata.
+  void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    metadata_[key] = value;
+  }
+
+  // Sets the node name.
+  void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    name_ = name;
+    type_ = TypeFromName(name);
+  }
+
+  // Sets the node output.
+  void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    output_ = output;
+  }
+
+  // Records that a node thread has started work.
+  void start_work() LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+  }
+
+  // Records that a node thread has stopped work.
+  void stop_work() LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    auto iter = work_start_.find(std::this_thread::get_id());
+    CHECK(work_start_.end() != iter)
+        << "Encountered a stop event that was not preceded by a start event.";
+    processing_time_ += Env::Default()->NowNanos() - iter->second;
+    work_start_.erase(iter);
+  }
+
+ private:
+  // Represents a performance knob.
+  struct Knob {
+    Node* node;
+    int64 processing_time;
+    int64 value;
+  };
+
+  enum class Type {
+    BATCH = 0,
+    CACHE,
+    CONCATENATE,
+    FILTER,
+    FLAT_MAP,
+    INTERLEAVE,
+    MAP,
+    MAP_AND_BATCH,
+    PADDED_BATCH,
+    PARALLEL_INTERLEAVE,
+    PARALLEL_INTERLEAVE_V2,
+    PARALLEL_MAP,
+    PREFETCH,
+    REPEAT,
+    SHUFFLE,
+    SKIP,
+    TAKE,
+    ZIP,
+    UNKNOWN,
+  };
+
+  // Collects performance knobs in the subtree rooted in this node.
+  void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
+
+  // Returns the per-element processing time spent in this node.
+  int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    return NanosPerElementLocked();
+  }
+
+  int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    if (num_elements_ == 0) {
+      return 0;
+    }
+    return (int64)((double)processing_time_ / (double)num_elements_);
+  }
+
+  // Returns the per-element output time for this node.
+  int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    return OutputTimeLocked(input_times);
+  }
+
+  int64 OutputTimeLocked(std::vector<int64>* input_times)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  int64 OutputTimeForInputs(std::vector<int64>* input_times)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    int64 sum = 0;
+    for (auto input : inputs_) {
+      sum += input->OutputTime(input_times);
+    }
+    return sum;
+  }
+
+  // Returns the per-element processing time spent in the subtree rooted in this
+  // node.
+  int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    return ProcessingTimeLocked();
+  }
+
+  int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Returns the per-element processing time spent in the inputs of this node.
+  int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    int64 sum = 0;
+    for (auto input : inputs_) {
+      sum += input->ProcessingTimeLocked();
+    }
+    return sum;
+  }
+
+  // Serializes the node state into the given proto.
+  void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
+    mutex_lock l(mu_);
+    node_proto->set_id(id_);
+    node_proto->set_name(name_);
+    node_proto->set_num_elements(num_elements_);
+    node_proto->set_processing_time(processing_time_);
+    for (const std::shared_ptr<Node>& input : inputs_) {
+      node_proto->add_input(input->id());
+    }
+    if (output_) {
+      node_proto->set_output(output_->id());
+    }
+    node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
+  }
+
+  Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    if (name_ == "Batch") {
+      return Type::BATCH;
+    }
+    if (str_util::EndsWith(name_, "Cache")) {
+      return Type::CACHE;
+    }
+    if (name_ == "Concatenate") {
+      return Type::CONCATENATE;
+    }
+    if (name_ == "Filter") {
+      return Type::FILTER;
+    }
+    if (name_ == "FlatMap") {
+      return Type::FLAT_MAP;
+    }
+    if (name_ == "Interleave") {
+      return Type::INTERLEAVE;
+    }
+    if (name_ == "Map") {
+      return Type::MAP;
+    }
+    if (name_ == "MapAndBatch") {
+      return Type::MAP_AND_BATCH;
+    }
+    if (name_ == "PaddedBatch") {
+      return Type::PADDED_BATCH;
+    }
+    if (name_ == "ParallelInterleave") {
+      return Type::PARALLEL_INTERLEAVE;
+    }
+    if (name_ == "ParallelInterleaveV2") {
+      return Type::PARALLEL_INTERLEAVE_V2;
+    }
+    if (name_ == "ParallelMap") {
+      return Type::PARALLEL_MAP;
+    }
+    if (name_ == "Prefetch") {
+      return Type::PREFETCH;
+    }
+    if (str_util::EndsWith(name_, "Repeat")) {
+      return Type::REPEAT;
+    }
+    if (name_ == "Shuffle") {
+      return Type::SHUFFLE;
+    }
+    if (str_util::EndsWith(name_, "Skip")) {
+      return Type::SKIP;
+    }
+    if (str_util::EndsWith(name_, "Take")) {
+      return Type::TAKE;
+    }
+    if (name_ == "Zip") {
+      return Type::ZIP;
+    }
+    return Type::UNKNOWN;
+  }
+
+  mutex mu_;
+  const int64 id_;
+  Type type_ GUARDED_BY(mu_);
+  string name_ GUARDED_BY(mu_);
+  int64 processing_time_ GUARDED_BY(mu_) = 0;
+  int64 num_elements_ GUARDED_BY(mu_) = 0;
+  std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+  std::map<string, int64> metadata_ GUARDED_BY(mu_);
+  std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+  std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+
+  friend class Model;
+};
+
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of performance knobs.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boiler plate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+//
+// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
+// into the input pipeline.
+class Model {
+ public:
+  Model() = default;
+  explicit Model(const proto::Model& model_proto);
+
+  ~Model() {}
+
+  // Returns the model output node.
+  std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    return output_;
+  }
+
+  // Adds a node with the given name and given output (identified by name).
+  std::shared_ptr<Node> AddNode(const string& name, const string& output_name)
+      LOCKS_EXCLUDED(mu_);
+
+  // Looks up the node using the given name.
+  std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+  // Runs optimization.
+  void Optimize() LOCKS_EXCLUDED(mu_);
+
+  // Outputs the state of a model to a file.
+  //
+  // TODO(jsimsa): Remove this method once the optimization loop is closed.
+  void OutputToFile() LOCKS_EXCLUDED(mu_);
+
+  // Removes the node identified by the given name.
+  void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
+
+  // Serializes the model state to the given proto.
+  void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
+
+ private:
+  static void AddNodeToProto(const std::shared_ptr<Node>& node,
+                             proto::Model* model_proto);
+
+  std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  mutex mu_;
+  int64 id_counter_ GUARDED_BY(mu_) = 1;
+  std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+  std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+}  // namespace model
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
diff --git a/tensorflow/core/framework/model.proto b/tensorflow/core/framework/model.proto
new file mode 100644
index 0000000..2600000
--- /dev/null
+++ b/tensorflow/core/framework/model.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package tensorflow.data.model.proto;
+option cc_enable_arenas = true;
+
+message Model {
+  // Counter used for generating new node IDs.
+  int64 id_counter = 1;
+  // Nodes of this model.
+  repeated Node node = 2;
+  // The ID of the output node.
+  int64 output = 3;
+};
+
+message Node {
+  // The node ID.
+  int64 id = 1;
+  // The node name.
+  string name = 2;
+  // Input node IDs.
+  repeated int64 input = 3;
+  // Output node ID.
+  int64 output = 4;
+  // Number of elements produced by the node.
+  int64 num_elements = 5;
+  // The CPU time spent by running threads of this node.
+  int64 processing_time = 6;
+  // Key-value store for node metadata (e.g. batch size or parallelism).
+  map<string, int32> metadata = 7;
+};
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 0a19861..ebdaaec 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -271,7 +271,7 @@
                          "]");
 }
 
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
   return ctx->input(input).flat<ResourceHandle>()(0);
 }
 
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index f8a587c..d58deaa 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -79,7 +79,7 @@
   virtual string DebugString() = 0;
 
   // Returns memory used by this resource.
-  virtual int64 MemoryUsed() const { return 0; };
+  virtual int64 MemoryUsed() const { return 0; }
 };
 
 // Container used for per-step resources.
@@ -234,7 +234,7 @@
                                          const string& name);
 
 // Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
                        ResourceHandle* handle);
 
@@ -348,6 +348,8 @@
 
   void Compute(OpKernelContext* ctx) override;
 
+  bool IsExpensive() override { return false; }
+
  private:
   string container_;
   string name_;
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 4a18efc..af53ed0 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -25,6 +25,8 @@
 
 class Summary;
 
+namespace data {
+
 // A `StatsAggregator` accumulates statistics incrementally. A
 // `StatsAggregator` can accumulate multiple different statistics, distinguished
 // by a string name.
@@ -87,6 +89,7 @@
   const std::shared_ptr<StatsAggregator> stats_aggregator_;
 };
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5..696fd27 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@
 class AllocationDescription;
 class Allocator;
 class OpKernelContext;
+class Tensor;
 class TensorBuffer;
 class TensorCApi;
 class TensorDescription;
 class TensorProto;
-class VariantTensorData;
+
 namespace batch_util {
 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
 Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c..9a78cdc 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/lib/math/math_util.h"
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 4bda8f9..a7cf600 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
 
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 
 #include <vector>
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 15b1add..2e96b05 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,7 +30,6 @@
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/resource_handle.h"
 #include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@
 
 namespace tensorflow {
 
+class Variant;
+
 // MemoryType is used to describe whether input or output Tensors of
 // an OpKernel should reside in "Host memory" (e.g., CPU memory) or
 // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a50780..d43e3c72 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@
 
 namespace tensorflow {
 
-bool Variant::TryDecode(Variant* out) const {
-  const VariantTensorDataProto* p = get<VariantTensorDataProto>();
-  if (p == nullptr) return false;
-  VariantTensorData data(*p);
-  return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+  if (!is_empty()) {
+    return value_->Decode(std::move(data));
+  }
+  return true;
 }
 
 template <>
@@ -54,13 +54,12 @@
 template <>
 void EncodeVariant(const VariantTensorDataProto& value,
                    VariantTensorData* data) {
-  data->FromProto(value);
+  data->FromConstProto(value);
 }
 
 template <>
-bool DecodeVariant(const VariantTensorData& data,
-                   VariantTensorDataProto* value) {
-  data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+  data->ToProto(value);
   return true;
 }
 
@@ -70,8 +69,8 @@
 }
 
 template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
-  return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+  return value->ParseFromString(*buf);
 }
 
 void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@
     if (variant_array[i].is_empty()) {
       variant_array[i] = VariantTensorDataProto();
     }
+    // TODO(ebrevdo): Replace with StringPiece?  Any way to make this a
+    // zero-copy operation that keeps a reference to the data in d?
     string str(d->Data(sizes[i]), sizes[i]);
-    if (!variant_array[i].Decode(str)) return false;
+    if (!variant_array[i].Decode(std::move(str))) return false;
     if (!DecodeUnaryVariant(&variant_array[i])) {
       LOG(ERROR) << "Could not decode variant with type_name: \""
                  << variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index 5273280..10eabbc8 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -23,7 +23,6 @@
 #include <unordered_map>
 #include <utility>
 
-#include "tensorflow/core/framework/tensor.pb.h"  // TODO(b/62899350): Remove
 #include "tensorflow/core/framework/type_index.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -38,18 +37,20 @@
 template <typename T>
 string DebugStringVariant(const T& value);
 
+// Allows for specializations of Variant Decoding.  `data` may be modified in
+// the process of decoding to `value`.
+template <typename T>
+bool DecodeVariant(VariantTensorData* data, T* value);
+
+template <typename T>
+bool DecodeVariant(string* buf, T* value);
+
 template <typename T>
 void EncodeVariant(const T& value, VariantTensorData* data);
 
 template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
-
-template <typename T>
 void EncodeVariant(const T& value, string* buf);
 
-template <typename T>
-bool DecodeVariant(const string& buf, T* value);
-
 // This is an implementation of a type-erased container that can store an
 // object of any type. The implementation is very similar to std::any, but has
 // restrictions on the types of objects that can be stored, and eschews some of
@@ -67,7 +68,7 @@
 //
 //   string TypeName() const;
 //   void Encode(VariantTensorData* data) const;
-//   void Decode(const VariantTensorData& data);
+//   void Decode(VariantTensorData data);
 //
 // Simple POD types can elide the Encode/Decode functions, they are provided by
 // helper methods.
@@ -121,7 +122,7 @@
 //   x.Encode(&serialized_f);
 //
 //   Variant y = Foo(); // default constructed Foo.
-//   y.Decode(&serialized_f);
+//   y.Decode(std::move(serialized_f));
 //   EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
 //
 //
@@ -145,10 +146,6 @@
 //   EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName());  // Looks like Foo.
 //   EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
 //             y_type_unknown.TypeId());
-//   // Decode and get y_type_unknown; compare to value in x.
-//   Foo f_decoded;
-//   EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-//   EXPECT_EQ(f_decoded, f);
 //
 class Variant {
  public:
@@ -241,12 +238,7 @@
   }
 
   // Deserialize `data` and update the stored object.
-  bool Decode(const VariantTensorData& data) {
-    if (!is_empty()) {
-      return value_->Decode(data);
-    }
-    return true;
-  }
+  bool Decode(VariantTensorData data);
 
   // Helper methods to directly serialize/deserialize from strings.
   void Encode(string* buf) const {
@@ -254,31 +246,13 @@
       value_->Encode(buf);
     }
   }
-  bool Decode(const string& buf) {
+  bool Decode(string buf) {
     if (!is_empty()) {
-      return value_->Decode(buf);
+      return value_->Decode(std::move(buf));
     }
     return true;
   }
 
-  template <typename T>
-  bool MaybeDecodeAndCopy(T* out) const {
-    const T* ret = get<T>();
-    if (ret != nullptr) {
-      *out = std::move(*ret);
-      return true;
-    };
-    Variant decoded = T();
-    if (!TryDecode(&decoded)) return false;
-    T* decoded_ret = decoded.get<T>();
-    CHECK_NOTNULL(decoded_ret);
-    *out = std::move(*decoded_ret);
-    return true;
-  }
-
- private:
-  bool TryDecode(Variant* out) const;
-
  private:
   struct in_place_t {};
   static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@
     virtual string TypeName() const = 0;
     virtual string DebugString() const = 0;
     virtual void Encode(VariantTensorData* data) const = 0;
-    virtual bool Decode(const VariantTensorData& data) = 0;
+    virtual bool Decode(VariantTensorData data) = 0;
     virtual void Encode(string* buf) const = 0;
-    virtual bool Decode(const string& data) = 0;
+    virtual bool Decode(string data) = 0;
   };
 
   template <typename T>
@@ -325,15 +299,13 @@
       EncodeVariant(value, data);
     }
 
-    bool Decode(const VariantTensorData& data) override {
-      return DecodeVariant(data, &value);
+    bool Decode(VariantTensorData data) override {
+      return DecodeVariant(&data, &value);
     }
 
     void Encode(string* buf) const override { EncodeVariant(value, buf); }
 
-    bool Decode(const string& buf) override {
-      return DecodeVariant(buf, &value);
-    }
+    bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
 
     T value;
   };
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index f155aa4..5e08e5a 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@
 
 // Specialization for POD type
 template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
                        TypeResolver<T, true /* is_pod */, false /* Tensor */,
                                     false /* protobuf */>,
                        T* value) {
@@ -90,7 +91,7 @@
 
 // Specialization for tensorflow::Tensor
 template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
                        TypeResolver<T, false /* is_pod */, true /* Tensor */,
                                     false /* protobuf */>,
                        T* value) {
@@ -100,7 +101,7 @@
 
 // Specialization for protobuf
 template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
                                     true /* protobuf */>,
                        T* value) {
@@ -111,11 +112,11 @@
 
 // Specialization for other types
 template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
                        TypeResolver<T, false /* is_pod */, false /* Tensor */,
                                     false /* protobuf */>,
                        T* value) {
-  return value->Decode(data);
+  return value->Decode(std::move(data));
 }
 
 template <typename C, typename = void>
@@ -224,8 +225,8 @@
 }
 
 template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
-  return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+  return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
 }
 
 template <typename T>
@@ -238,26 +239,31 @@
 }
 
 template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
   VariantTensorData data;
-  if (!data.ParseFromString(buf)) return false;
-  if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+  if (!data.ParseFromString(*buf)) return false;
+  if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+    return false;
+  }
   return true;
 }
 
 // Specializations for VariantTensorDataProto
 template <>
 string TypeNameVariant(const VariantTensorDataProto& value);
+
 template <>
 void EncodeVariant(const VariantTensorDataProto& value,
                    VariantTensorData* data);
+
 template <>
-bool DecodeVariant(const VariantTensorData& data,
-                   VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
 template <>
 void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
 template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
 
 // Encodes an array of Variant objects in to the given StringListEncoder.
 // `variant_array` is assumed to point to an array of `n` Variant objects.
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd..daa744e 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@
 
 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
     StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
-    "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+    StoredTensorValue::CopyCPUToGPU);
 
 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
     StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
-    "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+    StoredTensorValue::CopyGPUToCPU);
 
 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
     StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
-    "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+    StoredTensorValue::CopyGPUToGPU);
 
 REGISTER_OP("CreateTestVariant")
     .Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1..ef5b240 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@
 }
 
 UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
-    StringPiece type_name) {
-  auto found = shape_fns.find(type_name);
+    const TypeIndex& type_index) {
+  auto found = shape_fns.find(type_index);
   if (found == shape_fns.end()) return nullptr;
   return &found->second;
 }
 
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
                                              const VariantShapeFn& shape_fn) {
-  CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
-  VariantShapeFn* existing = GetShapeFn(type_name);
+  VariantShapeFn* existing = GetShapeFn(type_index);
   CHECK_EQ(existing, nullptr)
-      << "Unary VariantShapeFn for type_name: " << type_name
-      << " already registered";
-  shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
-      GetPersistentStringPiece(type_name), shape_fn));
+      << "Unary VariantShapeFn for type_index: "
+      << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+  shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
 }
 
 Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@
   CHECK_EQ(variant_tensor.dims(), 0);
   const Variant& v = variant_tensor.scalar<Variant>()();
   UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
-      UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+      UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
   if (shape_fn == nullptr) {
     return errors::Internal(
-        "No unary variant shape function found for Variant type_name: ",
-        v.TypeName());
+        "No unary variant shape function found for Variant type_index: ",
+        port::MaybeAbiDemangle(v.TypeId().name()));
   }
   return (*shape_fn)(v, shape);
 }
@@ -79,7 +77,7 @@
 }  // namespace
 
 #define REGISTER_VARIANT_SHAPE_TYPE(T) \
-  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
 
 // No encode/shape registered for std::complex<> and Eigen::half
 // objects yet.
@@ -143,25 +141,24 @@
 
 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
 UnaryVariantOpRegistry::GetDeviceCopyFn(
-    const VariantDeviceCopyDirection direction, StringPiece type_name) {
-  auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+    const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+  auto found = device_copy_fns.find(std::make_pair(direction, type_index));
   if (found == device_copy_fns.end()) return nullptr;
   return &found->second;
 }
 
 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
-    const VariantDeviceCopyDirection direction, const string& type_name,
+    const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
     const AsyncVariantDeviceCopyFn& device_copy_fn) {
-  CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
-  AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+  AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
   CHECK_EQ(existing, nullptr)
       << "UnaryVariantDeviceCopy for direction: " << direction
-      << " and type_name: " << type_name << " already registered";
+      << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+      << " already registered";
   device_copy_fns.insert(
-      std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
-                AsyncVariantDeviceCopyFn>(
-          std::make_pair(direction, GetPersistentStringPiece(type_name)),
-          device_copy_fn));
+      std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+                AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+                                          device_copy_fn));
 }
 
 Status VariantDeviceCopy(
@@ -170,35 +167,34 @@
     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
-                                                        from.TypeName());
+                                                        from.TypeId());
   if (device_copy_fn == nullptr) {
     return errors::Internal(
         "No unary variant device copy function found for direction: ",
-        direction, " and Variant type_name: ", from.TypeName());
+        direction, " and Variant type_index: ",
+        port::MaybeAbiDemangle(from.TypeId().name()));
   }
   return (*device_copy_fn)(from, to, copy_fn);
 }
 
 // Special casing UnaryOpFn per op and per device.
 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
-    VariantUnaryOp op, StringPiece device, StringPiece type_name) {
-  auto found = unary_op_fns.find({op, device, type_name});
+    VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+  auto found = unary_op_fns.find({op, device, type_index});
   if (found == unary_op_fns.end()) return nullptr;
   return &found->second;
 }
 
 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
-    VariantUnaryOp op, const string& device, const string& type_name,
+    VariantUnaryOp op, const string& device, const TypeIndex& type_index,
     const VariantUnaryOpFn& unary_op_fn) {
-  CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
-  VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+  VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
   CHECK_EQ(existing, nullptr)
-      << "Unary VariantUnaryOpFn for type_name: " << type_name
+      << "Unary VariantUnaryOpFn for type_index: "
+      << port::MaybeAbiDemangle(type_index.name())
       << " already registered for device type: " << device;
   unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
-      {op, GetPersistentStringPiece(device),
-       GetPersistentStringPiece(type_name)},
-      unary_op_fn));
+      {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
 }
 
 namespace {
@@ -212,7 +208,7 @@
 
 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
-                                           DEVICE_CPU, T, TF_STR(T),    \
+                                           DEVICE_CPU, T,               \
                                            ZerosLikeVariantPrimitiveType<T>);
 
 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@
 // Special casing BinaryOpFn per op and per device.
 UnaryVariantOpRegistry::VariantBinaryOpFn*
 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
-                                      StringPiece type_name) {
-  auto found = binary_op_fns.find({op, device, type_name});
+                                      const TypeIndex& type_index) {
+  auto found = binary_op_fns.find({op, device, type_index});
   if (found == binary_op_fns.end()) return nullptr;
   return &found->second;
 }
 
 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
-    VariantBinaryOp op, const string& device, const string& type_name,
+    VariantBinaryOp op, const string& device, const TypeIndex& type_index,
     const VariantBinaryOpFn& add_fn) {
-  CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
-  VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+  VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
   CHECK_EQ(existing, nullptr)
-      << "Unary VariantBinaryOpFn for type_name: " << type_name
+      << "Unary VariantBinaryOpFn for type_index: "
+      << port::MaybeAbiDemangle(type_index.name())
       << " already registered for device type: " << device;
   binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
-      {op, GetPersistentStringPiece(device),
-       GetPersistentStringPiece(type_name)},
-      add_fn));
+      {op, GetPersistentStringPiece(device), type_index}, add_fn));
 }
 
 namespace {
@@ -257,8 +251,7 @@
 
 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
-                                            T, TF_STR(T),                      \
-                                            AddVariantPrimitiveType<T>);
+                                            T, AddVariantPrimitiveType<T>);
 
 // No add registered for std::complex<> or Eigen::half objects yet.
 REGISTER_VARIANT_ADD_TYPE(int);
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index e6a2665..7eb37e8 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -22,10 +22,14 @@
 
 #define EIGEN_USE_THREADS
 
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
 
 namespace tensorflow {
 
@@ -90,10 +94,11 @@
       AsyncVariantDeviceCopyFn;
 
   // Add a shape lookup function to the registry.
-  void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+  void RegisterShapeFn(const TypeIndex& type_index,
+                       const VariantShapeFn& shape_fn);
 
-  // Returns nullptr if no shape function was found for the given TypeName.
-  VariantShapeFn* GetShapeFn(StringPiece type_name);
+  // Returns nullptr if no shape function was found for the given TypeIndex.
+  VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
 
   // Add a decode function to the registry.
   void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@
 
   // Add a copy-to-GPU function to the registry.
   void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
-                            const string& type_name,
+                            const TypeIndex& type_index,
                             const AsyncVariantDeviceCopyFn& device_copy_fn);
 
   // Returns nullptr if no copy function was found for the given
   // TypeName and direction.
   AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
-      const VariantDeviceCopyDirection direction, StringPiece type_name);
+      const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
 
   // Add a unary op function to the registry.
   void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
-                         const string& type_name,
+                         const TypeIndex& type_index,
                          const VariantUnaryOpFn& unary_op_fn);
 
   // Returns nullptr if no unary op function was found for the given
   // op, device, and TypeName.
   VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
-                                 StringPiece type_name);
+                                 const TypeIndex& type_index);
 
   // Add a binary op function to the registry.
   void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
-                          const string& type_name,
+                          const TypeIndex& type_index,
                           const VariantBinaryOpFn& add_fn);
 
   // Returns nullptr if no binary op function was found for the given
   // op, device and TypeName.
   VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
-                                   StringPiece type_name);
+                                   const TypeIndex& type_index);
 
   // Get a pointer to a global UnaryVariantOpRegistry object
   static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@
   static std::unordered_set<string>* PersistentStringStorage();
 
  private:
-  std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
-  std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
-      decode_fns;
+  struct TypeIndexHash {
+    std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+  };
+
+  gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+  gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
 
   // Map std::pair<Direction, type_name> to function.
   struct PairHash {
     template <typename Direction>
-    std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+    std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
       // The hash of an enum is just its value as a std::size_t.
       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
-      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+      ret = Hash64Combine(ret, std::get<1>(x).hash_code());
       return ret;
     }
-    StringPieceHasher sp_hasher_;
   };
 
-  std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
-                     AsyncVariantDeviceCopyFn, PairHash>
+  gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+               AsyncVariantDeviceCopyFn, PairHash>
       device_copy_fns;
 
   // Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@
   // and references therein
   template <typename Op>
   struct FuncTuple {
-    FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
-        : op_type_(op), device_(dev), typename_(tname){};
+    FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+        : op_type_(op), device_(dev), type_index_(type_index) {}
     Op op_type_;
-    StringPiece device_, typename_;
+    StringPiece device_;
+    TypeIndex type_index_;
   };
   // friend declaration for operator==
   // needed for clang
@@ -184,11 +192,11 @@
   struct TupleHash {
     template <typename Op>
     std::size_t operator()(
-        const std::tuple<Op, StringPiece, StringPiece>& x) const {
+        const std::tuple<Op, StringPiece, TypeIndex>& x) const {
       // The hash of an enum is just its value as a std::size_t.
       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
       ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
-      ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+      ret = Hash64Combine(ret, std::get<2>(x).hash_code());
       return ret;
     }
 
@@ -197,14 +205,14 @@
       // The hash of an enum is just its value as a std::size_t.
       std::size_t ret = static_cast<std::size_t>(x.op_type_);
       ret = Hash64Combine(ret, sp_hasher_(x.device_));
-      ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+      ret = Hash64Combine(ret, x.type_index_.hash_code());
       return ret;
     }
     StringPieceHasher sp_hasher_;
   };
-  std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+  gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
       unary_op_fns;
-  std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+  gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
       binary_op_fns;
 
   // Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@
 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                        const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
   return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
-         (lhs.typename_ == rhs.typename_);
+         (lhs.type_index_ == rhs.type_index_);
 }
 // Gets a TensorShape from a Tensor containing a scalar Variant.
 // Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@
                       Variant* v_out) {
   const string& device = DeviceName<Device>::value;
   UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
-      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
   if (unary_op_fn == nullptr) {
     return errors::Internal(
         "No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@
 template <typename Device>
 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                         const Variant& a, const Variant& b, Variant* out) {
-  if (a.TypeName() != b.TypeName()) {
+  if (a.TypeId() != b.TypeId()) {
     return errors::Internal(
         "BianryOpVariants: Variants a and b have different "
-        "type names: '",
+        "type ids.  Type names: '",
         a.TypeName(), "' vs. '", b.TypeName(), "'");
   }
   const string& device = DeviceName<Device>::value;
   UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
-      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
   if (binary_op_fn == nullptr) {
     return errors::Internal(
         "No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@
  public:
   typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
 
-  UnaryVariantShapeRegistration(const string& type_name,
+  UnaryVariantShapeRegistration(const TypeIndex& type_index,
                                 const LocalVariantShapeFn& shape_fn) {
+    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
     UnaryVariantOpRegistry::Global()->RegisterShapeFn(
-        type_name,
-        [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+        type_index,
+        [type_index_name, shape_fn](const Variant& v,
+                                    TensorShape* s) -> Status {
           const T* t = v.get<T>();
           if (t == nullptr) {
             return errors::Internal(
-                "VariantShapeFn: Could not access object, type_name: ",
-                type_name);
+                "VariantShapeFn: Could not access object, type_index: ",
+                type_index_name);
           }
           return shape_fn(*t, s);
         });
@@ -355,11 +365,11 @@
             return false;
           }
           Variant decoded = T();
-          VariantTensorData data(*t);
-          if (!decoded.Decode(data)) {
+          VariantTensorData data(std::move(*t));
+          if (!decoded.Decode(std::move(data))) {
             return false;
           }
-          *v = std::move(decoded);
+          std::swap(decoded, *v);
           return true;
         });
   }
@@ -372,11 +382,12 @@
                                UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
       LocalVariantDeviceCopyFn;
   UnaryVariantDeviceCopyRegistration(
-      const VariantDeviceCopyDirection direction, const string& type_name,
+      const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
       const LocalVariantDeviceCopyFn& device_copy_fn) {
+    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
     UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
-        direction, type_name,
-        [type_name, device_copy_fn](
+        direction, type_index,
+        [type_index_name, device_copy_fn](
             const Variant& from, Variant* to,
             UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                 device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@
           *to = T();
           if (from.get<T>() == nullptr) {
             return errors::Internal(
-                "VariantCopyToGPUFn: Could not access object, type_name: ",
-                type_name);
+                "VariantCopyToGPUFn: Could not access object, type_index: ",
+                type_index_name);
           }
           const T& t = *from.get<T>();
           T* t_out = to->get<T>();
@@ -401,18 +412,19 @@
 
  public:
   UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
-                                  const string& type_name,
+                                  const TypeIndex& type_index,
                                   const LocalVariantUnaryOpFn& unary_op_fn) {
+    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
     UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
-        op, device, type_name,
-        [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
-                                 Variant* v_out) -> Status {
+        op, device, type_index,
+        [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+                                       Variant* v_out) -> Status {
           DCHECK_NE(v_out, nullptr);
           *v_out = T();
           if (v.get<T>() == nullptr) {
             return errors::Internal(
-                "VariantUnaryOpFn: Could not access object, type_name: ",
-                type_name);
+                "VariantUnaryOpFn: Could not access object, type_index: ",
+                type_index_name);
           }
           const T& t = *v.get<T>();
           T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@
 
  public:
   UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
-                                   const string& type_name,
+                                   const TypeIndex& type_index,
                                    const LocalVariantBinaryOpFn& binary_op_fn) {
+    const string type_index_name = port::MaybeAbiDemangle(type_index.name());
     UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
-        op, device, type_name,
-        [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
-                                  const Variant& b, Variant* out) -> Status {
+        op, device, type_index,
+        [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+                                        const Variant& b,
+                                        Variant* out) -> Status {
           DCHECK_NE(out, nullptr);
           *out = T();
           if (a.get<T>() == nullptr) {
             return errors::Internal(
-                "VariantBinaryOpFn: Could not access object 'a', type_name: ",
-                type_name);
+                "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+                type_index_name);
           }
           if (b.get<T>() == nullptr) {
             return errors::Internal(
-                "VariantBinaryOpFn: Could not access object 'b', type_name: ",
-                type_name);
+                "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+                type_index_name);
           }
           const T& t_a = *a.get<T>();
           const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@
 
 // Register a unary shape variant function with the signature:
 //    Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function)    \
-  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
-                                                    shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(             \
+      __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
 
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
-                                                          shape_function)    \
-  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+                                                          shape_function)     \
+  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
 
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name,          \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index,         \
                                                    shape_function)             \
   static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
-      register_unary_variant_op_shape_registration_fn_##ctr(type_name,         \
+      register_unary_variant_op_shape_registration_fn_##ctr(type_index,        \
                                                             shape_function)
 
 // Register a unary decode variant function for the given type.
@@ -519,63 +533,63 @@
 // ****** NOTE ******
 // FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
 // ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(       \
-    T, direction, type_name, device_copy_fn)                        \
-  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
-      __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
+                                                             device_copy_fn) \
+  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
+      __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
 
 #define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
-    ctr, T, direction, type_name, device_copy_fn)                         \
+    ctr, T, direction, type_index, device_copy_fn)                        \
   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
-      ctr, T, direction, type_name, device_copy_fn)
+      ctr, T, direction, type_index, device_copy_fn)
 
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(             \
-    ctr, T, direction, type_name, device_copy_fn)                              \
-  static variant_op_registry_fn_registration::                                 \
-      UnaryVariantDeviceCopyRegistration<T>                                    \
-          register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
-                                                         device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+    ctr, T, direction, type_index, device_copy_fn)                 \
+  static variant_op_registry_fn_registration::                     \
+      UnaryVariantDeviceCopyRegistration<T>                        \
+          register_unary_variant_op_device_copy_fn_##ctr(          \
+              direction, type_index, device_copy_fn)
 
 // Register a unary unary_op variant function with the signature:
 //    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
 // for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
-                                                 unary_op_function)        \
-  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                    \
-      __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
+                                                 unary_op_function) \
+  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
+      __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
 
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                  \
-    ctr, op, device, T, type_name, unary_op_function)                          \
-  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
-                                                unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
+    ctr, op, device, T, type_index, unary_op_function)              \
+  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+                                                type_index, unary_op_function)
 
 #define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                         \
-    ctr, op, device, T, type_name, unary_op_function)                          \
+    ctr, op, device, T, type_index, unary_op_function)                         \
   static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
       T>                                                                       \
-      register_unary_variant_op_decoder_fn_##ctr(op, device, type_name,        \
+      register_unary_variant_op_decoder_fn_##ctr(op, device, type_index,       \
                                                  unary_op_function)
 
 // Register a binary_op variant function with the signature:
 //    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
 // for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
-                                                  binary_op_function)       \
-  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(                    \
-      __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
+                                                  binary_op_function) \
+  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
+      __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
 
 #define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
-    ctr, op, device, T, type_name, binary_op_function)         \
+    ctr, op, device, T, type_index, binary_op_function)        \
   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
-      ctr, op, device, T, type_name, binary_op_function)
+      ctr, op, device, T, type_index, binary_op_function)
 
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                     \
-    ctr, op, device, T, type_name, binary_op_function)                      \
-  static variant_op_registry_fn_registration::                              \
-      UnaryVariantBinaryOpRegistration<T>                                   \
-          register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                      \
+    ctr, op, device, T, type_index, binary_op_function)                      \
+  static variant_op_registry_fn_registration::                               \
+      UnaryVariantBinaryOpRegistration<T>                                    \
+          register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
                                                      binary_op_function)
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62..b2443e8 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@
   int value;
 };
 
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
-                                      VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
 
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
 
 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
     VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
-    "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+    VariantValue::CPUToGPUCopyFn);
 
 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_CPU, VariantValue,
-                                         "TEST VariantValue",
                                          VariantValue::CPUZerosLikeFn);
 
 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_GPU, VariantValue,
-                                         "TEST VariantValue",
                                          VariantValue::GPUZerosLikeFn);
 
 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
-                                          VariantValue, "TEST VariantValue",
-                                          VariantValue::CPUAddFn);
+                                          VariantValue, VariantValue::CPUAddFn);
 
 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
-                                          VariantValue, "TEST VariantValue",
-                                          VariantValue::GPUAddFn);
+                                          VariantValue, VariantValue::GPUAddFn);
 
 }  // namespace
 
 TEST(VariantOpShapeRegistryTest, TestBasic) {
-  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+  class Blah {};
+  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
             nullptr);
 
-  auto* shape_fn =
-      UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+  auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+      MakeTypeIndex<VariantValue>());
   EXPECT_NE(shape_fn, nullptr);
   TensorShape shape;
 
@@ -142,10 +138,11 @@
 TEST(VariantOpShapeRegistryTest, TestDuplicate) {
   UnaryVariantOpRegistry registry;
   UnaryVariantOpRegistry::VariantShapeFn f;
-  string kTypeName = "fjfjfj";
-  registry.RegisterShapeFn(kTypeName, f);
-  EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
-               "fjfjfj already registered");
+  class FjFjFj {};
+  const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+  registry.RegisterShapeFn(kTypeIndex, f);
+  EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+               "FjFjFj already registered");
 }
 
 TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@
 
 TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
   // No registered copy fn for GPU<->GPU.
-  EXPECT_EQ(
-      UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
-          VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
-      nullptr);
+  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+                VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+                MakeTypeIndex<VariantValue>()),
+            nullptr);
 
   auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
-      VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+      VariantDeviceCopyDirection::HOST_TO_DEVICE,
+      MakeTypeIndex<VariantValue>());
   EXPECT_NE(copy_to_gpu_fn, nullptr);
 
   VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@
 TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
   UnaryVariantOpRegistry registry;
   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
-  string kTypeName = "fjfjfj";
+  class FjFjFj {};
+  const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
   registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
-                                kTypeName, f);
+                                kTypeIndex, f);
   EXPECT_DEATH(registry.RegisterDeviceCopyFn(
-                   VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
-               "fjfjfj already registered");
+                   VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+               "FjFjFj already registered");
 }
 
 TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+  class Blah {};
   EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
-                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
             nullptr);
 
   VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@
 
 #if GOOGLE_CUDA
 TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+  class Blah {};
   EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
-                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+                ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
             nullptr);
 
   VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@
 TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
   UnaryVariantOpRegistry registry;
   UnaryVariantOpRegistry::VariantUnaryOpFn f;
-  string kTypeName = "fjfjfj";
+  class FjFjFj {};
+  const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
 
-  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
-                             f);
+  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+                             kTypeIndex, f);
   EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
-                                          DEVICE_CPU, kTypeName, f),
-               "fjfjfj already registered");
+                                          DEVICE_CPU, kTypeIndex, f),
+               "FjFjFj already registered");
 
-  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
-                             f);
+  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+                             kTypeIndex, f);
   EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
-                                          DEVICE_GPU, kTypeName, f),
-               "fjfjfj already registered");
+                                          DEVICE_GPU, kTypeIndex, f),
+               "FjFjFj already registered");
 }
 
 TEST(VariantOpAddRegistryTest, TestBasicCPU) {
-  return;
+  class Blah {};
   EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
-                ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+                ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
             nullptr);
 
   VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@
 
 #if GOOGLE_CUDA
 TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+  class Blah {};
   EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
-                ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+                ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
             nullptr);
 
   VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@
 TEST(VariantOpAddRegistryTest, TestDuplicate) {
   UnaryVariantOpRegistry registry;
   UnaryVariantOpRegistry::VariantBinaryOpFn f;
-  string kTypeName = "fjfjfj";
+  class FjFjFj {};
+  const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
 
-  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
   EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
-                                           kTypeName, f),
-               "fjfjfj already registered");
+                                           kTypeIndex, f),
+               "FjFjFj already registered");
 
-  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
   EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
-                                           kTypeName, f),
-               "fjfjfj already registered");
+                                           kTypeIndex, f),
+               "FjFjFj already registered");
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc..3e67e4a 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@
 
 VariantTensorData::VariantTensorData() {}
 
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
-  FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+  FromProto(std::move(proto));
 }
 
 VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@
   }
 }
 
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+  // TODO(ebrevdo): Do this lazily.
+  set_type_name(proto.type_name());
+  set_metadata(proto.metadata());
+  for (const auto& tensor : proto.tensors()) {
+    Tensor tmp;
+    if (!tmp.FromProto(tensor)) return false;
+    tensors_.push_back(tmp);
+  }
+  return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
   set_type_name(proto.type_name());
   set_metadata(proto.metadata());
   for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@
   return proto.SerializeToString(buf);
 }
 
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
   VariantTensorDataProto proto;
   const bool status = proto.ParseFromString(s);
-  if (status) FromProto(proto);
+  if (status) FromProto(std::move(proto));
   return status;
 }
 
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 7500e77..8a240ee 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -19,13 +19,13 @@
 #include <algorithm>
 #include <vector>
 
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 
 class VariantTensorDataProto;
-class Tensor;
 
 // The serialization format for Variant objects. Objects with references to
 // other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@
 class VariantTensorData {
  public:
   VariantTensorData();
-  VariantTensorData(const VariantTensorDataProto& proto);
+  VariantTensorData(VariantTensorDataProto proto);
   ~VariantTensorData();
 
   // Name of the type of objects being serialized.
@@ -68,12 +68,14 @@
 
   // Conversion to and from VariantTensorDataProto
   void ToProto(VariantTensorDataProto* proto) const;
-  bool FromProto(const VariantTensorDataProto& proto);
+  // This allows optimizations via std::move.
+  bool FromProto(VariantTensorDataProto proto);
+  bool FromConstProto(const VariantTensorDataProto& proto);
 
   // Serialization via VariantTensorDataProto
   string SerializeAsString() const;
   bool SerializeToString(string* buf);
-  bool ParseFromString(const string& s);
+  bool ParseFromString(string s);
 
   string DebugString() const;
 
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47..08d09de 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@
 struct TensorList {
   void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
 
-  bool Decode(const VariantTensorData& data) {
-    vec = data.tensors_;
+  bool Decode(VariantTensorData data) {
+    vec = std::move(data.tensors_);
     return true;
   }
 
@@ -186,7 +186,7 @@
   x.Encode(&serialized);
 
   Variant y = TensorList();
-  y.Decode(serialized);
+  y.Decode(std::move(serialized));
 
   const TensorList& decoded_vec = *y.get<TensorList>();
   for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@
   EXPECT_EQ(y_unknown.DebugString(),
             strings::StrCat(
                 "Variant<type: TensorList value: ", data.DebugString(), ">"));
-
-  TensorList unknown_decoded_vec;
-  EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
-  for (int i = 0; i < 4; ++i) {
-    EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
-  }
-  for (int i = 0; i < 4; ++i) {
-    EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
-  }
 }
 
 TEST(VariantTest, VariantArray) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index ee10194..7399613 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,12 +1042,12 @@
   }
 
   if (processed < node_defs_.size()) {
-    LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+    LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
                  << " NODES IN A CYCLE";
     for (int64 i = 0; i < node_defs_.size(); i++) {
       if (pending_count_[i] != 0) {
         LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
-                     << "WITH PENDING COUNT = " << pending_count_[i];
+                     << " WITH PENDING COUNT = " << pending_count_[i];
       }
     }
     return errors::InvalidArgument(node_defs_.size() - processed,
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f..0a38aa1 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@
   return ret;
 }
 
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+                  .Input(in)
+                  .Attr("message", message)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+                  .Attr("T", type)
+                  .Attr("index", index)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+  Node* ret;
+  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+                  .Input(in)
+                  .Attr("index", index)
+                  .Finalize(g, &ret));
+  return ret;
+}
+
 void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
 
 }  // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 8585b35..bd0284d 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -209,6 +209,15 @@
 // Add a DiagPart node in "g".
 Node* DiagPart(Graph* g, Node* in, DataType type);
 
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
 }  // end namespace graph
 }  // end namespace test
 }  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 6710ff9..d273edd 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -345,6 +345,56 @@
   }
 }
 
+bool IsShapeFullyDefinedIntegerVectorOrScalar(
+    InferenceContext* ic, const ShapeHandle& shape,
+    const ShapeHandle& tensor_as_shape, const DataType& dtype) {
+  if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
+      !ic->FullyDefined(tensor_as_shape) ||
+      (dtype != DT_INT32 && dtype != DT_INT64)) {
+    return false;
+  }
+  return true;
+}
+
+// Returned tensor's shape is like `shape`, and its values and dtype are from
+// `tensor_as_shape` and `dtype`.
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
+                                     const ShapeHandle& shape,
+                                     const ShapeHandle& tensor_as_shape,
+                                     const DataType& dtype) {
+  TensorProto tensor_proto;
+  tensor_proto.set_dtype(dtype);
+  auto* shape_proto = tensor_proto.mutable_tensor_shape();
+  if (ic->Rank(shape) == 1) {
+    shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
+  }
+  // For a scalar tensor, tensor_shape field will be left empty; no dim.
+  for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
+    int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
+    if (dtype == DT_INT32) {
+      tensor_proto.add_int_val(value);
+    } else {
+      tensor_proto.add_int64_val(value);
+    }
+  }
+  return tensor_proto;
+}
+
+// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
+// and dtype = `dtype`.
+NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
+                                  const ShapeHandle& shape,
+                                  const ShapeHandle& tensor_as_shape,
+                                  const DataType& dtype) {
+  NodeDef const_node;
+  const_node.set_name("const_from_shape");
+  const_node.set_op("Const");
+  auto* attr = const_node.mutable_attr();
+  (*attr)["dtype"].set_type(dtype);
+  auto* tensor = (*attr)["value"].mutable_tensor();
+  *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype);
+  return const_node;
+}
 }  // namespace
 
 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -429,18 +479,22 @@
   // perform shape inference on the function body.
   //
   // Propagate shape information of final function body node
-  // to function node `node`.
+  // to function node `function_node`.
   //
-  // In the event of an error, UpdateNode will simply set `node`'s
+  // In the event of an error, UpdateNode will simply set `function_node`'s
   // output shape to be Unknown.
-  Status UpdateFunction(const NodeDef* node) {
-    auto it = fun_to_grappler_function_item_.find(node->op());
+  Status UpdateFunction(const NodeDef* function_node) {
+    auto it = fun_to_grappler_function_item_.find(function_node->op());
     if (it == fun_to_grappler_function_item_.end()) {
       return errors::InvalidArgument(
-          node->op(), " was not previously added to SymbolicShapeRefiner.");
+          function_node->op(),
+          " was not previously added to SymbolicShapeRefiner.");
     }
 
-    GrapplerFunctionItem& grappler_function_item = it->second;
+    // Copy (not reference) so that changes we make here (e.g., replacing
+    // Placeholder with Const) don't affect one in
+    // fun_to_grappler_function_item_.
+    GrapplerFunctionItem grappler_function_item = it->second;
     GraphView gv(&grappler_function_item.graph);
 
     // Forward shapes from function input nodes to argument nodes.
@@ -453,7 +507,7 @@
             "supported.");
       }
       NodeDef* fun_node = gv.GetNode(fun_input.input_name);
-      const string& input = node->input(i);
+      const string& input = function_node->input(i);
       const string& node_name = NodeName(input);
 
       if (IsControlInput(input)) {
@@ -478,17 +532,48 @@
       TensorShapeProto proto;
       const auto& handle = input_inference_context->output(output_port_num);
       input_inference_context->ShapeHandleToProto(handle, &proto);
+      // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
+      for (int i = 0; i < proto.dim_size(); i++) {
+        if (proto.dim(i).size() < -1) {
+          proto.mutable_dim(i)->set_size(-1);
+        }
+      }
       *attr_output_shape.mutable_shape() = proto;
       (*fun_node->mutable_attr())["shape"] = attr_output_shape;
     }
 
+    // Replace input Placeholders with Consts, if values are known. Note that
+    // we don't check exceptions here as it's done in the above loop.
+    auto* ctx = GetNodeContext(function_node);
+    auto* ic = ctx->inference_context.get();
+    for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+      const string& input = function_node->input(i);
+      const string& node_name = NodeName(input);
+      NodeDef* input_node = graph_.GetNode(node_name);
+      if (IsConstant(*input_node)) {
+        TF_CHECK_OK(
+            ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+      } else if (ic->input_tensors_as_shapes().size() > i &&
+                 IsShapeFullyDefinedIntegerVectorOrScalar(
+                     ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+                     ctx->input_types[i])) {
+        // We have fully defined input_tensors_as_shapes for this input; use it
+        // as a const input to the function node.
+        NodeDef const_input_node = MakeConstNodeDefFromShape(
+            ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+            ctx->input_types[i]);
+        TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
+                                          &grappler_function_item));
+      }
+    }
+
     // Perform inference on function body.
     GraphProperties gp(grappler_function_item);
     TF_RETURN_IF_ERROR(gp.InferStatically(true));
 
     // Add return nodes for output shapes.
-    auto ic = GetContext(node);
     int output = 0;
+    ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
     for (auto const& out_arg : grappler_function_item.outputs()) {
       if (out_arg.output_tensors.size() > 1) {
         // TODO(jmdecker): Handle case of multiple output tensors
@@ -505,8 +590,9 @@
 
       const NodeDef* retnode = gv.GetNode(node_name);
       if (retnode == nullptr) {
-        return errors::FailedPrecondition("Unable to find return node ",
-                                          node_name, " for ", node->name());
+        return errors::FailedPrecondition(
+            "Unable to find return function_node ", node_name, " for ",
+            function_node->name());
       }
 
       auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -520,6 +606,14 @@
       ShapeHandle out;
       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
       ic->set_output(output, out);
+      if (outprop.has_value()) {
+        // Forward tensor value to output_tensors_as_shape.
+        Tensor tensor;
+        if (tensor.FromProto(outprop.value())) {
+          MaybeSetTensorValueToShape(ic, tensor,
+                                     &ctx->output_tensors_as_shapes[output]);
+        }
+      }
       output++;
     }
 
@@ -562,21 +656,9 @@
           if (const_values[dst_input].FromProto(
                   input->attr().at("value").tensor())) {
             input_tensors[dst_input] = &const_values[dst_input];
-            // Integer tensors of rank one can also be interpreted as a shape
-            // provided all their values are >= -1.
-            if (const_values[dst_input].dims() == 1 &&
-                (const_values[dst_input].dtype() == DT_INT32 ||
-                 const_values[dst_input].dtype() == DT_INT64)) {
-              ShapeHandle tensor_shape = inference_context->Vector(
-                  const_values[dst_input].NumElements());
-              ShapeHandle shp;
-              if (inference_context
-                      ->MakeShapeFromTensor(input_tensors[dst_input],
-                                            tensor_shape, &shp)
-                      .ok()) {
-                input_tensors_as_shapes[dst_input] = shp;
-              }
-            }
+            MaybeSetTensorValueToShape(inference_context,
+                                       const_values[dst_input],
+                                       &input_tensors_as_shapes[dst_input]);
           }
         } else if (IsRank(*input)) {
           if (c->inference_context->RankKnown(c->inference_context->input(0))) {
@@ -671,11 +753,13 @@
       // true, as the updates to the call node will have changed, even if it's
       // the same function being called twice with the same input shapes.
       // Example: simple_function.pbtxt
-      if (UpdateFunction(node).ok()) {
+      auto s = UpdateFunction(node);
+      if (s.ok()) {
         return Status::OK();
       } else {
         VLOG(1) << "UpdateFunction failed for " << node->op()
-                << ". Defaulting to ShapeUnknown.";
+                << ". Defaulting to ShapeUnknown.\n"
+                << s.ToString();
       }
     }
 
@@ -942,13 +1026,25 @@
                                                 : t->scalar<int64>()();
             dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
           } else {
-            dims.push_back(ic->UnknownDim());
+            // Don't have tensor value, but use input_tensors_as_shapes, if
+            // possible.
+            const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
+            if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
+                ic->ValueKnown(ic->Dim(shape_handle, 0))) {
+              dims.push_back(ic->Dim(shape_handle, 0));
+            } else {
+              dims.push_back(ic->UnknownDim());
+            }
           }
         }
         if (valid) {
           c->output_tensors_as_shapes.resize(1);
           c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
         }
+      } else if (IsIdentity(node)) {
+        // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+        c->output_tensors_as_shapes.resize(1);
+        c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
       } else if (IsSlice(node)) {
         ShapeHandle input = ic->input_tensors_as_shapes()[0];
         bool valid = ic->RankKnown(input);
@@ -1053,6 +1149,46 @@
   }
 
  private:
+  bool IsIntegerVector(const Tensor& tensor) {
+    if (tensor.dims() == 1 &&
+        (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
+      return true;
+    }
+    return false;
+  }
+
+  bool IsIntegerScalar(const Tensor& tensor) {
+    if (tensor.dims() == 0 &&
+        (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
+        tensor.NumElements() == 1) {
+      return true;
+    }
+    return false;
+  }
+
+  void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
+                                  ShapeHandle* tensors_as_shapes) {
+    // Integer tensors of rank one can also be interpreted as a shape
+    // provided all their values are >= -1.
+    if (IsIntegerVector(tensor)) {
+      ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
+      ShapeHandle shp;
+      // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1).
+      if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) {
+        *tensors_as_shapes = shp;
+      }
+    } else if (IsIntegerScalar(tensor)) {
+      // Scalar constant.
+      int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
+                                               : tensor.flat<int64>()(0);
+      // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
+      // It's a limitation as we use ShapeHandle as a means to pass values.
+      if (value >= -1) {
+        *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
+      }
+    }
+  }
+
   const GraphView& graph_;
   int graph_def_version_;
   std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
@@ -1528,6 +1664,8 @@
       continue;
     }
 
+    auto* ic = ctx->inference_context.get();
+
     // Fill input properties.
     {
       auto& input_properties = input_properties_[node.name()];
@@ -1535,19 +1673,26 @@
       // Should always be empty, node names in graph are supposed to be unique.
       CHECK_EQ(input_properties.size(), 0);
 
-      input_properties.resize(ctx->inference_context->num_inputs());
+      input_properties.resize(ic->num_inputs());
       GraphView::InputPort input(&node, -1);
-      for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
-        shape_manager.AsTensorProperties(ctx->inference_context->input(i),
-                                         ctx->input_types[i],
+      for (int i = 0; i < ic->num_inputs(); ++i) {
+        shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i],
                                          &input_properties[i]);
         input.port_id = i;
         GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
-        if (!IsConstant(*fanin.node)) {
-          continue;
+        // Export tensor value (either const tensor or input_tensors_as_shapes)
+        // to input_properties.value.
+        if (IsConstant(*fanin.node)) {
+          const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+          *input_properties[i].mutable_value() = raw_val;
+        } else if (ic->input_tensors_as_shapes().size() > i &&
+                   IsShapeFullyDefinedIntegerVectorOrScalar(
+                       ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+                       ctx->input_types[i])) {
+          *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+              ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+              ctx->input_types[i]);
         }
-        const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
-        *input_properties[i].mutable_value() = raw_val;
       }
     }
 
@@ -1558,11 +1703,23 @@
       // Should always be empty, node names in graph are supposed to be unique.
       CHECK_EQ(output_properties.size(), 0);
 
-      output_properties.resize(ctx->inference_context->num_outputs());
-      for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
-        shape_manager.AsTensorProperties(ctx->inference_context->output(i),
-                                         ctx->output_types[i],
+      output_properties.resize(ic->num_outputs());
+      for (int i = 0; i < ic->num_outputs(); ++i) {
+        shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i],
                                          &output_properties[i]);
+        // Export tensor value (either const tensor or input_tensors_as_shapes)
+        // to output_properties.value.
+        if (IsConstant(node)) {
+          const TensorProto& raw_val = node.attr().at("value").tensor();
+          *output_properties[i].mutable_value() = raw_val;
+        } else if (ctx->output_tensors_as_shapes.size() > i &&
+                   IsShapeFullyDefinedIntegerVectorOrScalar(
+                       ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+                       ctx->output_types[i])) {
+          *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+              ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+              ctx->output_types[i]);
+        }
       }
     }
   }
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 8938b7c..362092a 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -44,6 +44,30 @@
     // Provision a single machine with 3 cpu cores
     cluster_.reset(new SingleMachine(5 * 60, 3, 0));
     TF_CHECK_OK(cluster_->Provision());
+
+    // This function is simply
+    // out = Fill(shape, value), but
+    // Fill requires values in the shape input, not just shape of it, to infer
+    // output shape.
+    auto f = FunctionDefHelper::Create(
+        // Name
+        "MyFillFunc",
+        // Inputs
+        {"shape: int32", "value: float"},
+        // Outputs
+        {"out: float"},
+        // Attrs
+        {},
+        // Nodes
+        {
+            {{"a"},
+             "Fill",
+             {"shape", "value"},
+             {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+        },
+        // Returns
+        {{"out", "a:output:0"}});
+    function_lib_.add_function()->Swap(&f);
   }
 
   void TearDown() override {
@@ -69,7 +93,29 @@
     return s;
   }
 
+  // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
+  // ones.
+  void ExpectTensorValues(const std::vector<int64>& expected,
+                          const TensorProto& tensor_proto_to_compare) {
+    Tensor tensor;
+    EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
+    EXPECT_EQ(expected.size(), tensor.NumElements());
+    // We're interested in only integer tensors as only shapes are exported as
+    // graph properties values.
+    CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
+    if (tensor.dtype() == DT_INT32) {
+      for (int i = 0; i < tensor.NumElements(); i++) {
+        EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
+      }
+    } else {
+      for (int i = 0; i < tensor.NumElements(); i++) {
+        EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
+      }
+    }
+  }
+
   std::unique_ptr<SingleMachine> cluster_;
+  FunctionDefLibrary function_lib_;
 };
 
 TEST_F(GraphPropertiesTest, StaticProperties) {
@@ -785,7 +831,220 @@
   EXPECT_EQ("float: [128,256]", PropToString(prop));
 }
 
-TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
+TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
+  Output a1 = ops::Identity(s.WithOpName("a1"), a);
+  Output b = ops::Const(s.WithOpName("b"), 99, {});
+  Output b1 = ops::Identity(s.WithOpName("b1"), b);
+  Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
+  Output c1 = ops::Identity(s.WithOpName("c1"), c);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+
+  // Check output shapes.
+  EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
+  EXPECT_EQ("int32: [2]",
+            PropToString(properties.GetOutputProperties("a1")[0]));
+  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
+  EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
+  EXPECT_EQ("int32: [4,4,4]",
+            PropToString(properties.GetOutputProperties("c")[0]));
+  EXPECT_EQ("int32: [4,4,4]",
+            PropToString(properties.GetOutputProperties("c1")[0]));
+
+  // Check has_value.
+  EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
+  EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
+  EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
+  EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
+  EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
+  EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
+  EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
+  EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
+  // Note that we propagate tensor value of only 1D vector and scalar.
+  EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value());
+
+  // Check values.
+  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
+  ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
+  ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
+  ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
+  ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
+  ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
+  std::vector<int64> c_values;
+  for (int i = 0; i < 4 * 4 * 4; i++) {
+    c_values.push_back(1);
+  }
+  ExpectTensorValues({c_values},
+                     properties.GetOutputProperties("c")[0].value());
+  ExpectTensorValues({c_values},
+                     properties.GetInputProperties("c1")[0].value());
+  // No output value for c1, as it's neither 1D vector nor scalar.
+}
+
+TEST_F(GraphPropertiesTest, IdentityPassingShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 5, {2});
+  Output b = ops::Identity(s.WithOpName("b"), a);
+  Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
+  // Fill needs not only e's shape but also the value of e to figure out output
+  // shape; hence, Identity op (b) should pass a's value as
+  // output_tensors_as_shape.
+  Output d = ops::Fill(s.WithOpName("fill"), b, c);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+  const auto out_props = properties.GetOutputProperties("fill");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithConstInput) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {});
+  Output b = ops::Const(s.WithOpName("b"), 2, {});
+  Output c = ops::Const(s.WithOpName("c"), 3, {});
+  Output d = ops::Const(s.WithOpName("d"), 4, {});
+  // Note ops::Stack instantiates Pack op.
+  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+  // Fill needs not only e's shape but also its value to figure out output
+  // shape.
+  Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+  const auto out_props = properties.GetOutputProperties("fill");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  // Same as the PackWithConstInput test case, but a, b, c, and d are Identity ops
+  // from Const.
+  // If output_tensors_as_shape is not set for those Identity ops or Pack op
+  // doesn't take input_tensors_as_shape, Fill op's input doesn't have value;
+  // hence, its output shape becomes unknown.
+  Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
+  Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
+  Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
+  Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
+  Output a = ops::Identity(s.WithOpName("a"), a0);
+  Output b = ops::Identity(s.WithOpName("b"), b0);
+  Output c = ops::Identity(s.WithOpName("c"), c0);
+  Output d = ops::Identity(s.WithOpName("d"), d0);
+  // Note ops::Stack instantiates Pack op.
+  Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+  // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+  Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+  // Fill needs not only e's shape but also its value to figure out output
+  // shape.
+  Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+  const auto out_props = properties.GetOutputProperties("fill");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+  Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
+  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+                                         s.graph()->op_registry());
+  tensorflow::Node* func_op;
+  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+  auto _value = tensorflow::ops::AsNodeOut(s, value);
+  TF_CHECK_OK(
+      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+  const auto out_props = properties.GetOutputProperties("MyFillFunc");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
+  // Same as FunctionWithConstInput, but function inputs are Identity of Const,
+  // so tensor shapes, not tensor value, should be used as Const input to
+  // function.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+  Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
+  Output shape = ops::Identity(s.WithOpName("shape"), shape_);
+  Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+  auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+                                         s.graph()->op_registry());
+  tensorflow::Node* func_op;
+  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+  auto _value = tensorflow::ops::AsNodeOut(s, value);
+  TF_CHECK_OK(
+      builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(false));
+  const auto out_props = properties.GetOutputProperties("MyFillFunc");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
+  FunctionDefLibrary library;
+  *library.add_function() = FunctionDefHelper::Create(
+      "MyFunc",                                                   // Name
+      {"x: int32"},                                               // Inputs
+      {"out: int32"},                                             // Outputs
+      {},                                                         // Attrs
+      {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}},  // Nodes
+      {{"out", "a:output:0"}});                                   // Returns
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+
+  // MyFunc takes Const (shape) and passes it with Identity. Expect function
+  // output has the same shape as well as value (output_tensors_as_shape) as
+  // input Const tensor.
+  Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
+  auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+  auto builder =
+      tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+  tensorflow::Node* func_op;
+  TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphProperties properties(item);
+  TF_CHECK_OK(properties.InferStatically(true));
+  const auto out_props = properties.GetOutputProperties("MyFunc");
+  const OpInfo::TensorProperties out_prop0 = out_props[0];
+  EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+  EXPECT_TRUE(out_prop0.has_value());
+  ExpectTensorValues({5, 7}, out_prop0.value());
+  ExpectTensorValues({5, 7},
+                     properties.GetInputProperties("MyFunc")[0].value());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
   // Create graph with a function that takes a scalar value so that we use
   // Placeholder with scalar as for input to the function shape inference.
   // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
@@ -818,7 +1077,7 @@
 
   // MyFunc output shouldn't be unknown rank.
   GraphProperties properties(item);
-  TF_CHECK_OK(properties.InferStatically(false));
+  TF_CHECK_OK(properties.InferStatically(true));
   const auto out_props = properties.GetOutputProperties("MyFunc");
   const OpInfo::TensorProperties out_prop0 = out_props[0];
   EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
@@ -856,18 +1115,10 @@
   EXPECT_EQ(2, in_props.size());
 
   const OpInfo::TensorProperties& in_prop = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop.dtype());
-  EXPECT_FALSE(in_prop.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop.shape().dim_size());
-  EXPECT_EQ(1, in_prop.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_FALSE(in_prop1.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
 }
 
 TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
@@ -882,51 +1133,25 @@
   EXPECT_EQ(2, out_props.size());
 
   const OpInfo::TensorProperties& out_prop0 = out_props[0];
-  EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
-  EXPECT_EQ(4, out_prop0.shape().dim_size());
-  EXPECT_EQ(128, out_prop0.shape().dim(0).size());
-  EXPECT_EQ(112, out_prop0.shape().dim(1).size());
-  EXPECT_EQ(112, out_prop0.shape().dim(2).size());
-  EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+  EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
 
   const OpInfo::TensorProperties& out_prop1 = out_props[1];
-  EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
-  EXPECT_EQ(128, out_prop1.shape().dim(0).size());
-  EXPECT_EQ(112, out_prop1.shape().dim(1).size());
-  EXPECT_EQ(112, out_prop1.shape().dim(2).size());
-  EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+  EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
 
   const auto in_props = properties.GetInputProperties("y0");
   EXPECT_EQ(4, in_props.size());
 
   const OpInfo::TensorProperties& in_prop0 = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
-  EXPECT_EQ(1, in_prop0.shape().dim_size());
-  EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+  EXPECT_EQ("float: [64]", PropToString(in_prop0));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_EQ(4, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(1, in_prop1.shape().dim(1).size());
-  EXPECT_EQ(24, in_prop1.shape().dim(2).size());
-  EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+  EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
 
   const OpInfo::TensorProperties& in_prop2 = in_props[2];
-  EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
-  EXPECT_EQ(4, in_prop2.shape().dim_size());
-  EXPECT_EQ(128, in_prop2.shape().dim(0).size());
-  EXPECT_EQ(224, in_prop2.shape().dim(1).size());
-  EXPECT_EQ(224, in_prop2.shape().dim(2).size());
-  EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+  EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
 
   const OpInfo::TensorProperties& in_prop3 = in_props[3];
-  EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
-  EXPECT_EQ(4, in_prop3.shape().dim_size());
-  EXPECT_EQ(7, in_prop3.shape().dim(0).size());
-  EXPECT_EQ(7, in_prop3.shape().dim(1).size());
-  EXPECT_EQ(3, in_prop3.shape().dim(2).size());
-  EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+  EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
 }
 
 TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
@@ -986,18 +1211,10 @@
   EXPECT_EQ(2, in_props.size());
 
   const OpInfo::TensorProperties& in_prop = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop.dtype());
-  EXPECT_FALSE(in_prop.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop.shape().dim_size());
-  EXPECT_EQ(1, in_prop.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_FALSE(in_prop1.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
 }
 
 TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
@@ -1022,27 +1239,16 @@
   const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
   const OpInfo::TensorProperties& out_prop = out_props[0];
   EXPECT_EQ(DT_FLOAT, out_prop.dtype());
-  EXPECT_FALSE(out_prop.shape().unknown_rank());
-  EXPECT_EQ(2, out_prop.shape().dim_size());
-  EXPECT_EQ(1, out_prop.shape().dim(0).size());
-  EXPECT_EQ(2, out_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(out_prop));
 
   const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
   EXPECT_EQ(2, in_props.size());
 
   const OpInfo::TensorProperties& in_prop = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop.dtype());
-  EXPECT_FALSE(in_prop.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop.shape().dim_size());
-  EXPECT_EQ(1, in_prop.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_FALSE(in_prop1.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
 }
 
 TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
@@ -1066,28 +1272,16 @@
   TF_CHECK_OK(properties.InferStatically(false));
   const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
   const OpInfo::TensorProperties& out_prop = out_props[0];
-  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
-  EXPECT_FALSE(out_prop.shape().unknown_rank());
-  EXPECT_EQ(2, out_prop.shape().dim_size());
-  EXPECT_EQ(1, out_prop.shape().dim(0).size());
-  EXPECT_EQ(2, out_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(out_prop));
 
   const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
   EXPECT_EQ(2, in_props.size());
 
   const OpInfo::TensorProperties& in_prop = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop.dtype());
-  EXPECT_FALSE(in_prop.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop.shape().dim_size());
-  EXPECT_EQ(1, in_prop.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_FALSE(in_prop1.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
 }
 
 TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
@@ -1115,28 +1309,16 @@
   TF_CHECK_OK(properties.InferStatically(false));
   const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
   const OpInfo::TensorProperties& out_prop = out_props[0];
-  EXPECT_EQ(DT_FLOAT, out_prop.dtype());
-  EXPECT_FALSE(out_prop.shape().unknown_rank());
-  EXPECT_EQ(2, out_prop.shape().dim_size());
-  EXPECT_EQ(1, out_prop.shape().dim(0).size());
-  EXPECT_EQ(2, out_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(out_prop));
 
   const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
   EXPECT_EQ(2, in_props.size());
 
   const OpInfo::TensorProperties& in_prop = in_props[0];
-  EXPECT_EQ(DT_FLOAT, in_prop.dtype());
-  EXPECT_FALSE(in_prop.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop.shape().dim_size());
-  EXPECT_EQ(1, in_prop.shape().dim(0).size());
-  EXPECT_EQ(2, in_prop.shape().dim(1).size());
+  EXPECT_EQ("float: [1,2]", PropToString(in_prop));
 
   const OpInfo::TensorProperties& in_prop1 = in_props[1];
-  EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
-  EXPECT_FALSE(in_prop1.shape().unknown_rank());
-  EXPECT_EQ(2, in_prop1.shape().dim_size());
-  EXPECT_EQ(1, in_prop1.shape().dim(0).size());
-  EXPECT_EQ(3, in_prop1.shape().dim(1).size());
+  EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
 }
 
 TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
index 26d38a4..9762634 100644
--- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
@@ -138,7 +138,7 @@
   // The entries are owned by collation_map_, so must be removed from
   // ordered_collation_ before removing them from collation_map_.
   struct ReverseLessByCount {
-    bool operator()(CollationEntry* left, CollationEntry* right) {
+    bool operator()(CollationEntry* left, CollationEntry* right) const {
       return left->count > right->count;  // Reverse order.
     }
   };
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088..e78239b 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,16 +135,37 @@
 
 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
 
-bool IsElementWiseMonotonic(const NodeDef& node) {
-  static const std::unordered_set<string>* element_wise_monotonic_ops =
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If *is_non_decreasing is true, the function is non-decreasing,
+// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
+// e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+  static const std::unordered_set<string>* monotonic_non_decreasing_ops =
       CHECK_NOTNULL((new std::unordered_set<string>{
-          "Relu",
-          "Relu6",
-          "Sigmoid",
-          "Sqrt",
-          "Tanh",
+          "Asinh", "Atanh",   "Ceil",  "Elu",  "Erf",  "Exp",   "Expm1",
+          "Floor", "Log",     "Log1p", "Relu", "Relu6", "Rint",
+          "Selu",  "Sigmoid", "Sign",  "Sinh", "Sqrt", "Tanh",
       }));
-  return element_wise_monotonic_ops->count(node.op()) > 0;
+  static const std::unordered_set<string>* monotonic_non_increasing_ops =
+      CHECK_NOTNULL((new std::unordered_set<string>{
+          "Inv",
+          "Reciprocal",
+          "Erfc",
+          "Rsqrt",
+          "Neg",
+      }));
+  if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+    if (is_non_decreasing) {
+      *is_non_decreasing = true;
+    }
+    return true;
+  } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+    if (is_non_decreasing) {
+      *is_non_decreasing = false;
+    }
+    return true;
+  }
+  return false;
 }
 
 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 9443926..25ab6b6 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,7 +55,7 @@
 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
 bool IsDequeueOp(const NodeDef& node);
 bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
 bool IsEluGrad(const NodeDef& node);
 bool IsEnter(const NodeDef& node);
 bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index a24004d..f094c15 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -846,3 +846,68 @@
         "//third_party/eigen3",
     ],
 )
+
+cc_library(
+    name = "function_api_info",
+    srcs = ["function_api_info.cc"],
+    hdrs = ["function_api_info.h"],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+    ],
+)
+
+tf_cc_test(
+    name = "function_api_info_test",
+    size = "small",
+    srcs = ["function_api_info_test.cc"],
+    deps = [
+        ":function_api_info",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
+cc_library(
+    name = "experimental_implementation_selector",
+    srcs = ["experimental_implementation_selector.cc"],
+    hdrs = ["experimental_implementation_selector.h"],
+    deps = [
+        ":custom_graph_optimizer",
+        ":custom_graph_optimizer_registry",
+        ":function_api_info",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/costs:graph_properties",
+    ],
+)
+
+tf_cc_test(
+    name = "experimental_implementation_selector_test",
+    size = "small",
+    srcs = ["experimental_implementation_selector_test.cc"],
+    deps = [
+        ":custom_graph_optimizer",
+        ":custom_graph_optimizer_registry",
+        ":experimental_implementation_selector",
+        ":function_api_info",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+        "//tensorflow/core/grappler/utils:grappler_test",
+    ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 4fed88d..11ce121 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1121,11 +1121,8 @@
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
     NodeDef* tail = node;
-    // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
-    if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
-      tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
-                                      *ctx().nodes_to_preserve);
-    }
+    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+                                    *ctx().nodes_to_preserve);
     NodeDef* first_transpose;
     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
 
@@ -2706,8 +2703,9 @@
     // 0. inner_function is not in the preserve set,
     // 1. inner_function's Op is element-wise monotonic
     // 2. inner_function's output is not being consumed elsewhere.
+    bool is_non_decreasing = false;
     if (!IsInPreserveSet(*inner_function) &&
-        IsElementWiseMonotonic(*inner_function) &&
+        IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
         ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
       // Swap the first inputs of the inner function Op & the reduction Op.
       NodeDef* inner_input;
@@ -2719,7 +2717,12 @@
       UpdateConsumers(reduction_node, inner_function->name());
       ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
                                   reduction_node->name());
-
+      if (!is_non_decreasing) {
+        // Flip Min<->Max if the function is non-increasing, e.g.
+        // Max(Neg(x)) = Neg(Min(x)).
+        const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+        reduction_node->set_op(opposite);
+      }
       AddToOptimizationQueue(reduction_node);
       AddToOptimizationQueue(inner_function);
       AddToOptimizationQueue(inner_input);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index bfccc0a..39517ed 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3248,6 +3248,48 @@
   VerifyGraphsMatch(item.graph, output, __LINE__);
 }
 
+TEST_F(ArithmeticOptimizerTest,
+       OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+  Output neg = ops::Neg(s.WithOpName("neg"), x);
+  Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+  Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+  GrapplerItem item;
+  item.fetch = {"final_out"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(1, tensors_expected.size());
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+  OptimizeAndPrune(&optimizer, &item, &output);
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(1, tensors.size());
+
+  test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+  EXPECT_EQ(item.graph.node_size(), output.node_size());
+  // Check if the inputs are switched
+  int required_node_count = 0;
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    if (node.name() == "neg") {
+      EXPECT_EQ("Neg", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("reduce_max", node.input(0));
+      ++required_node_count;
+    } else if (node.name() == "reduce_max") {
+      EXPECT_EQ("Min", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("x", node.input(0));
+      ++required_node_count;
+    }
+  }
+  EXPECT_EQ(2, required_node_count);
+}
+
 TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
new file mode 100644
index 0000000..eeea269
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
+
+Status ExperimentalImplementationSelector::LoadFunctions(
+    const GraphDef& graph) {
+  lib_info_.reset(new FunctionLibraryApiInfo);
+  TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+  return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+    NodeDef* node_def) const {
+  const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
+  if (info == nullptr) {
+    // A regular op, or a function which has no interface.
+    return Status::OK();
+  }
+
+  string task, device;
+  if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+    return errors::Internal("Could not split device name:", node_def->device());
+  }
+  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+          << " = (" << task << ", " << device << ")";
+  DeviceNameUtils::ParsedName parsed_name;
+  DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+  string best_function_name;
+  lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+                                   &best_function_name);
+  if (node_def->op() != best_function_name) {
+    // The current implementation is not the best, swap the op to the best one.
+    // There will be duplicates in the graph, and they will be pruned by other
+    // grappler plugins since no other node is using their outputs as inputs.
+    // TODO(scottzhu): Update the tf.eager.defun to register functions without
+    // having to call them with input data. That will reduce the graph size and
+// save the work of pruning them.
+    node_def->set_op(best_function_name);
+  }
+  return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::SelectImplementation(
+    GraphDef* graph) const {
+  for (int k = 0; k < graph->node_size(); ++k)
+    TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));
+
+  return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::Optimize(Cluster* cluster,
+                                                    const GrapplerItem& item,
+                                                    GraphDef* optimized_graph) {
+  *optimized_graph = item.graph;
+  TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
+  return SelectImplementation(optimized_graph);
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
new file mode 100644
index 0000000..82f7473
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// -- EXPERIMENTAL --
+// This transformation replaces function calls by the appropriate function
+// definition based on properties of the runtime system. For instance,
+// we may choose one implementation over another if we have a GPU with
+// enough memory available.
+//
+// It is a way for the programmer to specify alternative implementations
+// of the same functionality in the graph, and let TensorFlow pick the
+// most appropriate one at runtime.
+//
+// For instance, the python code might specify:
+// @Defun(tf.float32,
+//        experimental_api_implements='plus_one',
+//        experimental_api_preferred_device='GPU')
+// def plus_one_gpu(x): return x + 1.0
+//
+// @Defun(tf.float32,
+//        experimental_api_implements='plus_one')
+// def plus_one_reference_implementation(x): return x + 1.0
+// input = tf.constant(2.0, dtype=tf.float32)
+//
+// z = plus_one_reference_implementation(input)
+// z = plus_one_gpu(input)
+// print(sess.run(z))
+//
+// At runtime, we will trim either `plus_one_gpu` or
+// `plus_one_reference_implementation` based on the availability of the GPU.
+//
+// Available annotations:
+//  - experimental_api_implements(string): all functions mapping to the same
+//    string can be interchanged. For now, all functions must have the same
+//    signature and overloads are not allowed. Defuns within defuns are
+//    allowed.
+//  - experimental_api_preferred_device(string): sets which device is preferred.
+class ExperimentalImplementationSelector : public CustomGraphOptimizer {
+ public:
+  ExperimentalImplementationSelector() = default;
+  ~ExperimentalImplementationSelector() override = default;
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return Status::OK();
+  }
+  string name() const override {
+    return "experimental_implementation_selector";
+  }
+
+  // This call is not thread-safe.
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override;
+
+  // Does not take any feedback.
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+
+ private:
+  Status LoadFunctions(const GraphDef& graph);
+  Status MaybeOptimizeFunctionCall(NodeDef* node_def) const;
+
+  // Finds all call sites for functions, then replace with the appropriate
+  // implementation.
+  // There are two ways of calling functions:
+  //  1. By specifying an op name as a function name, and
+  //  2. Via the functional interface, where the function name appears as an
+  //  Attr.
+  //
+  // There may be multiple call sites for a given function. The function body
+  // may call into another function, so a function might have to be duplicated.
+  // For simplicity, we do not change function bodies. Also, we do not change
+  // gradients.
+  Status SelectImplementation(GraphDef* graph) const;
+
+  std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector);
+};
+
+}  // namespace grappler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000..2368e57
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,139 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  std::unique_ptr<CustomGraphOptimizer> optimizer =
+      CustomGraphOptimizerRegistry::CreateByNameOrNull(
+          "ExperimentalImplementationSelector");
+  ASSERT_NE(nullptr, optimizer);
+  TF_ASSERT_OK(optimizer->Init());
+
+  GraphDef output;
+  const Status status = optimizer->Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  // This is a trivial graph so there is nothing to update.
+  EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+  using test::function::NDef;
+  auto cpu_def = test::function::XTimesTwo();
+  auto* func_attr = cpu_def.mutable_attr();
+  (*func_attr)["experimental_api_implements"].set_s("times_two");
+  (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+  auto gpu_def = test::function::XAddX();
+  auto* func2_attr = gpu_def.mutable_attr();
+  (*func2_attr)["experimental_api_implements"].set_s("times_two");
+  (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+  ExperimentalImplementationSelector optimizer;
+  GraphDef output;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+       NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+       NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+       NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+       NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+      // FunctionLib
+      {cpu_def, gpu_def});
+
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  EXPECT_EQ(output.node_size(), 5);
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "y1") {
+      // Make sure the implementation has been swapped to use the GPU version.
+      EXPECT_EQ("XAddX", node.op());
+    } else if (node.name() == "y2") {
+      // Make sure the implementation is not changed.
+      EXPECT_EQ("XTimesTwo", node.op());
+    }
+  }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+  using test::function::NDef;
+  auto cpu_def = test::function::XTimesTwo();
+  auto* func_attr = cpu_def.mutable_attr();
+  (*func_attr)["experimental_api_implements"].set_s("random_boost");
+  (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+  auto gpu_def = test::function::XTimesFour();
+  auto* func2_attr = gpu_def.mutable_attr();
+  (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+  (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+  ExperimentalImplementationSelector optimizer;
+  GraphDef output;
+  GrapplerItem item;
+  item.graph = test::function::GDef(
+      {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+       NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+       NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+      // FunctionLib
+      {cpu_def, gpu_def});
+
+  const Tensor input = test::AsScalar<float>(1.0f);
+  item.fetch = {"z"};
+  item.feed.emplace_back("x", input);
+
+  const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+  test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+                                 test::AsScalar<float>(4.0f));
+
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+  GrapplerItem optimized(item, std::move(output));
+  const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+  test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+                                 test::AsScalar<float>(2.0f));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc
new file mode 100644
index 0000000..798e0f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+FunctionApiInfo::FunctionApiInfo() {}
+FunctionApiInfo::~FunctionApiInfo() {}
+
+Status FunctionApiInfo::Init(const FunctionDef& function_def) {
+  for (const auto& attr : function_def.attr()) {
+    if (attr.first == "experimental_api_preferred_device") {
+      preferred_device_ = attr.second.s();
+    }
+    if (attr.first == "experimental_api_implements") {
+      interface_name_ = attr.second.s();
+    }
+  }
+  if (interface_name_.empty() && !preferred_device_.empty()) {
+    return errors::InvalidArgument(
+        "Function '", function_def.signature().name(),
+        "' has a preferred device, but does not implement an interface");
+  }
+  return Status::OK();
+}
+
+const string& FunctionApiInfo::preferred_device() const {
+  return preferred_device_;
+}
+
+const string& FunctionApiInfo::interface_name() const {
+  return interface_name_;
+}
+
+FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
+FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
+
+namespace {
+bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) {
+  if (f1.ret().size() != f2.ret().size()) return false;
+  const auto& sig1 = f1.signature();
+  const auto& sig2 = f2.signature();
+  // Functions have positional semantics, so we don't check for names.
+  if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+  for (int k = 0; k < sig1.input_arg_size(); ++k) {
+    const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+    const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+    if (arg1.type() != arg2.type()) return false;
+    if (arg1.type_attr() != arg2.type_attr()) return false;
+    if (arg1.number_attr() != arg2.number_attr()) return false;
+    if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+    if (arg1.is_ref() != arg2.is_ref()) return false;
+  }
+  return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+                         const std::vector<const FunctionDef*>& equiv_funcs) {
+  if (equiv_funcs.size() < 2) return Status::OK();
+  for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+    if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+      return errors::InvalidArgument(
+          "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+          equiv_funcs[k]->signature().name(), "' both implement '",
+          interface_name, "' but their signatures do not match.");
+  }
+  return Status::OK();
+}
+
+Status ValidateSignatures(
+    const std::unordered_map<string, std::vector<const FunctionDef*>>&
+        intf_to_func) {
+  for (const auto& item : intf_to_func)
+    TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+  return Status::OK();
+}
+}  // namespace
+
+Status FunctionLibraryApiInfo::Init(
+    const FunctionDefLibrary& function_library) {
+  std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+  for (const auto& function : function_library.function()) {
+    std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+    TF_RETURN_IF_ERROR(func_info->Init(function));
+    // Ignore the function if it does not implement any interface.
+    if (func_info->interface_name().empty()) continue;
+
+    const string& function_name = function.signature().name();
+    const string& interface_name = func_info->interface_name();
+    func_to_intf_[function_name] = interface_name;
+    intf_to_funcs_[interface_name].emplace_back(function_name);
+    intf_to_func[interface_name].emplace_back(&function);
+    func_info_[function_name] = std::move(func_info);
+  }
+  TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+  return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+    const string& function_name, std::vector<string>* other_names) const {
+  const auto intf_it = func_to_intf_.find(function_name);
+  // The function does not implement any interface.
+  if (intf_it == func_to_intf_.end()) return;
+  CHECK(!intf_it->second.empty()) << "Function " << function_name
+                                  << "should at least implement 1 interface.";
+  const auto it = intf_to_funcs_.find(intf_it->second);
+  CHECK(it != intf_to_funcs_.end())
+      << "Function " << function_name << " maps to " << intf_it->second
+      << " but no reverse mapping was found";
+  CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+  other_names->reserve(it->second.size() - 1);
+  for (const auto& other_name : it->second) {
+    if (other_name == function_name) continue;
+    other_names->emplace_back(other_name);
+  }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+    const string& function_name, const string& device,
+    string* best_func_name) const {
+  CHECK(best_func_name != nullptr);
+  const auto func_it = func_to_intf_.find(function_name);
+  if (func_it == func_to_intf_.end()) return;
+
+  const auto it = intf_to_funcs_.find(func_it->second);
+  // No function found for the given interface.
+  if (it == intf_to_funcs_.end()) return;
+  for (const auto& func_name : it->second) {
+    const auto func_api_info = func_info_.find(func_name)->second.get();
+    if (func_api_info->preferred_device() == device) {
+      best_func_name->assign(func_name);
+      return;
+    }
+  }
+  // Didn't find a function with a matching device name; choose the first one
+  // among all the available functions.
+  best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+    const string& function_name) const {
+  const auto it = func_info_.find(function_name);
+  if (it == func_info_.end()) return nullptr;
+  return it->second.get();
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000..412687c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+  FunctionApiInfo();
+  virtual ~FunctionApiInfo();
+
+  Status Init(const FunctionDef& function_def);
+
+  const string& interface_name() const;
+  const string& preferred_device() const;
+
+ private:
+  string interface_name_;
+  string preferred_device_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information for a function and the interface it implements.
+// An interface is a well-defined math operation, e.g. I1 = 2 * x + y. Multiple
+// functions could implement the same interface with different behavior based
+// on different hardware conditions and limits,
+// eg F1 = math_ops.add(math_ops.add(x, x), y), or
+//    F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+  FunctionLibraryApiInfo();
+  virtual ~FunctionLibraryApiInfo();
+  // Populates the internal fields for the functions within the function_library.
+  Status Init(const FunctionDefLibrary& function_library);
+
+  void GetEquivalentImplementations(const string& function_name,
+                                    std::vector<string>* other_names) const;
+
+  void GetBestImplementation(const string& function_name, const string& device,
+                             string* best_func_name) const;
+
+  const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+  // Map between function name to function details.
+  std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+  // Map between function name to interface name.
+  std::unordered_map<string, string> func_to_intf_;
+  // Map between interface name to function names.
+  std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000..582890d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+            OpDef::ArgDef* arg_def) {
+  arg_def->set_name(name);
+  arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec;  // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+  for (const auto& arg_spec : args_spec)
+    SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+  SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+                      const string& preferred_device,
+                      const std::vector<ArgSpec>& input_args,
+                      FunctionDef* func_def) {
+  OpDef* sig = func_def->mutable_signature();
+  sig->set_name(name);
+
+  SetArgs(input_args, sig);
+
+  if (!api_interface_name.empty() || !preferred_device.empty()) {
+    auto* func_attr = func_def->mutable_attr();
+    if (!api_interface_name.empty())
+      (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+    if (!preferred_device.empty())
+      (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+  }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+                           FunctionDefLibrary* func_lib) {
+  const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+  const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+                                             {"in2", "int32"}};
+  PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+                   func_lib->add_function());
+  PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+                   mismatch_args ? func_wrong_args : func_args,
+                   func_lib->add_function());
+  PopulateFunction("DoThings", "DoThings", "", func_args,
+                   func_lib->add_function());
+  PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+  PopulateFunction("AnotherOneOff", "", "", func_args,
+                   func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+                    const string& func_name,
+                    const std::vector<string>& expected_other) {
+  std::vector<string> other_impl;
+  lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+  const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+  const std::unordered_set<string> expected(expected_other.begin(),
+                                            expected_other.end());
+  return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+                      const string& function_name, const string& device,
+                      const string& expected_function_name) {
+  string best_function_name;
+  lib_api_info.GetBestImplementation(function_name, device,
+                                     &best_function_name);
+
+  return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+                        const string& func_name) {
+  auto* info = lib_api_info.GetApiInfo(func_name);
+  CHECK_NOTNULL(info);
+  return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+                          const string& func_name) {
+  auto* info = lib_api_info.GetApiInfo(func_name);
+  CHECK_NOTNULL(info);
+  return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+  FunctionDefLibrary func_lib;
+  PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+  FunctionLibraryApiInfo lib_api_info;
+  TF_ASSERT_OK(lib_api_info.Init(func_lib));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+  EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+  EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+  EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+  EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+  EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+  // TPU impl is not available, choose the first one available which is the CPU.
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu"));
+}
+
+TEST(FunctionApiInfoTest, MismatchedArguments) {
+  FunctionDefLibrary func_lib;
+  PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);
+  FunctionLibraryApiInfo lib_api_info;
+  const Status ret = lib_api_info.Init(func_lib);
+  EXPECT_FALSE(ret.ok());
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 5fd34ef..8c99598 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -156,7 +156,7 @@
     optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
         cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
   }
-  return Status::OK();
+  return InitializeCustomGraphOptimizers(optimizers);
 }
 
 Status MetaOptimizer::InitializeOptimizersByName(
@@ -180,6 +180,11 @@
       VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
     }
   }
+  return InitializeCustomGraphOptimizers(optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
   for (const auto& optimizer_config : cfg_.custom_optimizers()) {
     auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
         optimizer_config.name());
@@ -208,7 +213,7 @@
   }
 
   std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
-  if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+  if (cfg_.optimizers().empty()) {
     TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
   } else {
     TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
@@ -326,10 +331,12 @@
 
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
+  LOG(INFO) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
   // 1. Optimize main graph
   TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+  VLOG(1) << "Optimized main graph.";
 
   // 2. Optimize function library
   FunctionLibraryDefinition flib(OpRegistry::Global(),
@@ -393,7 +400,7 @@
     }
   }
 
-  VLOG(3) << "Optimized " << optimized_funcs.size()
+  VLOG(1) << "Optimized " << optimized_funcs.size()
           << " functions: " << str_util::Join(optimized_funcs, ", ");
 
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 151a54c..831c5e3 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -52,6 +52,9 @@
   // Initialize active optimizers from RewriterConfig optimizer names.
   Status InitializeOptimizersByName(
       std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+  // Initialize active optimizers from RewriterConfig.custom_optimizers.
+  Status InitializeCustomGraphOptimizers(
+      std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
 
   // Run optimization pass over a single GrapplerItem. Meta optimizer might run
   // multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 9a03c7d..e74e0f7 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -64,6 +64,13 @@
 
 REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
 
+class TestGraphOptimizer : public TestOptimizer {
+ public:
+  string name() const override { return "test_graph_optimizer"; }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+
 class MetaOptimizerTest : public GrapplerTest {};
 
 TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -83,6 +90,27 @@
   EXPECT_TRUE(TestOptimizer::IsOptimized());
 }
 
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  TestOptimizer::SetOptimized(false);
+  TestGraphOptimizer::SetOptimized(false);
+  RewriterConfig rewriter_config;
+  rewriter_config.add_optimizers("TestOptimizer");
+  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+  customGraphOptimizer->set_name("TestGraphOptimizer");
+  rewriter_config.set_min_graph_nodes(-1);
+
+  MetaOptimizer optimizer(nullptr, rewriter_config);
+  GraphDef output;
+  const Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  EXPECT_TRUE(TestOptimizer::IsOptimized());
+  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
 TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
   GrapplerItem item;
@@ -98,6 +126,24 @@
   TF_EXPECT_OK(status);
 }
 
+TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
+  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+  GrapplerItem item;
+  CHECK(fake_input.NextItem(&item));
+
+  RewriterConfig rewriter_config;
+  auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+  customGraphOptimizer->set_name("TestGraphOptimizer");
+  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+  rewriter_config.set_min_graph_nodes(-1);
+
+  MetaOptimizer optimizer(nullptr, rewriter_config);
+  GraphDef output;
+  const Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
 TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
   using test::function::NDef;
 
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a2c363e..a428aea 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -304,21 +304,21 @@
 }
 
 GrapplerFunctionItem::GrapplerFunctionItem(
-    const string& func_name, const string& description,
-    const AttrValueMap& func_attr,
-    const std::vector<InputArgExpansion>& input_arg_expansions,
-    const std::vector<OutputArgExpansion>& output_arg_expansions,
-    const std::vector<string>& keep_nodes, const int graph_def_version,
-    bool is_stateful, GraphDef&& function_body)
-    : description_(description),
-      func_attr_(func_attr),
-      input_arg_expansions_(input_arg_expansions),
-      output_arg_expansions_(output_arg_expansions),
+    string func_name, string description, AttrValueMap func_attr,
+    std::vector<InputArgExpansion> input_arg_expansions,
+    std::vector<OutputArgExpansion> output_arg_expansions,
+    std::vector<string> keep_nodes, const int graph_def_version,
+    const bool is_stateful, GraphDef&& function_body)
+    : description_(std::move(description)),
+      func_attr_(std::move(func_attr)),
+      input_arg_expansions_(std::move(input_arg_expansions)),
+      output_arg_expansions_(std::move(output_arg_expansions)),
       is_stateful_(is_stateful) {
-  id = func_name;
-  keep_ops = keep_nodes;
-  // Swap the graph body.
-  graph.Swap(&function_body);
+  // Move assign GrapplerItem members.
+  keep_ops = std::move(keep_nodes);
+  id = std::move(func_name);
+  graph = std::move(function_body);
+
   graph.mutable_versions()->set_producer(graph_def_version);
   // Fill the feed nodes with input placeholders.
   for (const InputArgExpansion& input_arg : input_arg_expansions_) {
@@ -598,8 +598,8 @@
   *item = GrapplerFunctionItem(
       /*func_name=*/signature.name(), /*description=*/signature.description(),
       /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
-      inputs, outputs, keep_nodes, graph_def_version, is_stateful,
-      std::move(function_body));
+      std::move(inputs), std::move(outputs), std::move(keep_nodes),
+      graph_def_version, is_stateful, std::move(function_body));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 61588ce..733caf3 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -136,13 +136,12 @@
 class GrapplerFunctionItem : public GrapplerItem {
  public:
   GrapplerFunctionItem() = default;
-  GrapplerFunctionItem(
-      const string& func_name, const string& description,
-      const AttrValueMap& func_attr,
-      const std::vector<InputArgExpansion>& input_arg_expansions,
-      const std::vector<OutputArgExpansion>& output_arg_expansions,
-      const std::vector<string>& keep_nodes, const int versions,
-      bool is_stateful, GraphDef&& function_body);
+  GrapplerFunctionItem(string func_name, string description,
+                       AttrValueMap func_attr,
+                       std::vector<InputArgExpansion> input_arg_expansions,
+                       std::vector<OutputArgExpansion> output_arg_expansions,
+                       std::vector<string> keep_nodes, int graph_def_version,
+                       bool is_stateful, GraphDef&& function_body);
 
   const string& description() const;
 
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4fb7aab..ceb9f5d 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -24,15 +24,16 @@
 namespace tensorflow {
 namespace grappler {
 
-// Compute modified strongly connected components:
+// Computes modified strongly connected components:
 // All nodes that are not part of a loop are assigned the special -1 id
 // All nodes that are part of at least one loop are assigned a positive
 // component id: if 2 nodes v and w are reachable from one another (i.e. if they
 // belong to the same scc), they'll be assigned the same id, otherwise they'll
-// be assigned distinct ids. Returns the number of distinct ids.
+// be assigned distinct ids. *num_components is set to the number of distinct
+// ids.
 void StronglyConnectedComponents(
     const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
-    int* num_ids);
+    int* num_components);
 
 // Returns the number of individual loops present in the graph, and populate the
 // 'loops' argument with the collection of loops (denoted by their loop ids) a
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 25063ac..94d3ab4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -643,14 +643,7 @@
         ":split_v_op",
         ":strided_slice_op",
         ":tile_ops",
-    ] + if_mkl(
-        [
-            ":mkl_transpose_op",
-        ],
-        [
-            ":transpose_op",
-        ],
-    ) + [
+        ":transpose_op",
         ":unique_op",
         ":unpack_op",
         ":unravel_index_op",
@@ -893,24 +886,13 @@
     deps = ARRAY_DEPS,
 )
 
-if_mkl(
-    [tf_mkl_kernel_library(
-        name = "mkl_transpose_op",
-        srcs = [
-            "mkl_transpose_op.cc",
-            "transpose_op.cc",
-        ],
-        hdrs = ["transpose_op.h"],
-        deps = ARRAY_DEPS + mkl_deps(),
-    )],
-    [tf_kernel_library(
-        name = "transpose_op",
-        srcs = [
-            "transpose_op.cc",
-        ],
-        hdrs = ["transpose_op.h"],
-        deps = ARRAY_DEPS,
-    )],
+tf_kernel_library(
+    name = "transpose_op",
+    srcs = [
+        "transpose_op.cc",
+    ],
+    hdrs = ["transpose_op.h"],
+    deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]),
 )
 
 tf_kernel_library(
@@ -4522,6 +4504,25 @@
     deps = STRING_DEPS,
 )
 
+tf_cc_test(
+    name = "substr_op_test",
+    size = "small",
+    srcs = ["substr_op_test.cc"],
+    deps = [
+        ":substr_op",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:ops_testutil",
+        "//tensorflow/core/kernels:ops_util",
+    ],
+)
+
 tf_kernel_library(
     name = "as_string_op",
     prefix = "as_string_op",
@@ -5202,6 +5203,7 @@
         "fifo_queue.cc",
         "fifo_queue_op.cc",
         "fused_batch_norm_op.cc",
+        "listdiff_op.cc",
         "population_count_op.cc",
         "population_count_op.h",
         "winograd_transform.h",
@@ -6351,6 +6353,15 @@
     deps = NN_DEPS + mkl_deps() + [":cwise_op"],
 )
 
+tf_mkl_kernel_library(
+    name = "mkl_transpose_op",
+    srcs = [
+        "mkl_transpose_op.cc",
+    ],
+    hdrs = ["transpose_op.h"],
+    deps = ARRAY_DEPS + mkl_deps(),
+)
+
 # NOTE(lespeholt): This rule is deprecated, please use:
 # tensorflow/core/util/batch_util.h
 cc_library(
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021..4e8bfa0 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@
 
 tf_proto_library(
     name = "boosted_trees_proto",
-    srcs = ["boosted_trees.proto"],
+    srcs = [
+        "boosted_trees.proto",
+    ],
     cc_api_version = 2,
     visibility = ["//visibility:public"],
 )
@@ -87,9 +89,21 @@
 )
 
 tf_kernel_library(
+    name = "quantile_ops",
+    srcs = ["quantile_ops.cc"],
+    deps = [
+        "//tensorflow/core:boosted_trees_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+    ],
+)
+
+tf_kernel_library(
     name = "boosted_trees_ops",
     deps = [
         ":prediction_ops",
+        ":quantile_ops",
         ":resource_ops",
         ":stats_ops",
         ":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000..d184094
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+    boosted_trees::quantiles::WeightedQuantilesSummary<float,
+                                                       float>::SummaryEntry;
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+                                      const int64 num_boundaries) {
+  std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+  // Uniquify elements as we may get dupes.
+  auto end_it = std::unique(boundaries.begin(), boundaries.end());
+  boundaries.resize(std::distance(boundaries.begin(), end_it));
+  return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+                                     const int64 num_quantiles) {
+  // Do not de-dup boundaries. Exactly num_quantiles boundary values
+  // will be returned.
+  std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+  CHECK_EQ(boundaries.size(), num_quantiles);
+  return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+                              const OpInputList& buckets_list) {
+  const auto& buckets = buckets_list[feature].flat<float>();
+  std::vector<float> buckets_vector(buckets.data(),
+                                    buckets.data() + buckets.size());
+  return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+    Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+    IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+  explicit BoostedTreesCreateQuantileStreamResourceOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Only create one if one does not exist already. Report status for all
+    // other exceptions. If one already exists, it unrefs the new one.
+    // An epsilon value of zero could cause performance issues and is therefore
+    // disallowed.
+    const Tensor* epsilon_t;
+    OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+    float epsilon = epsilon_t->scalar<float>()();
+    OP_REQUIRES(
+        context, epsilon > 0,
+        errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+    const Tensor* num_streams_t;
+    OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+    int64 num_streams = num_streams_t->scalar<int64>()();
+
+    auto result =
+        new QuantileStreamResource(epsilon, max_elements_, num_streams);
+    auto status = CreateResource(context, HandleFromInput(context, 0), result);
+    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+      OP_REQUIRES(context, false, status);
+    }
+  }
+
+ private:
+  // An upper bound on the number of entries that the summaries might have
+  // for a feature.
+  int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+    BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+  explicit BoostedTreesMakeQuantileSummariesOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // Read the float features list.
+    OpInputList float_features_list;
+    OP_REQUIRES_OK(
+        context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+    // Parse example weights and get batch size.
+    const Tensor* example_weights_t;
+    OP_REQUIRES_OK(context,
+                   context->input(kExampleWeightsName, &example_weights_t));
+    auto example_weights = example_weights_t->flat<float>();
+    const int64 batch_size = example_weights.size();
+    const Tensor* epsilon_t;
+    OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+    float epsilon = epsilon_t->scalar<float>()();
+
+    OpOutputList summaries_output_list;
+    OP_REQUIRES_OK(
+        context, context->output_list(kSummariesName, &summaries_output_list));
+
+    auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+      // Iterating features.
+      for (int64 index = begin; index < end; index++) {
+        const auto feature_values = float_features_list[index].flat<float>();
+        QuantileStream stream(epsilon, batch_size + 1);
+        // Run quantile summary generation.
+        for (int64 j = 0; j < batch_size; j++) {
+          stream.PushEntry(feature_values(j), example_weights(j));
+        }
+        stream.Finalize();
+        const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+        Tensor* output_t;
+        OP_REQUIRES_OK(
+            context,
+            summaries_output_list.allocate(
+                index,
+                TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+                &output_t));
+        auto output = output_t->matrix<float>();
+        for (auto row = 0; row < summary_entry_list.size(); row++) {
+          const auto& entry = summary_entry_list[row];
+          output(row, 0) = entry.value;
+          output(row, 1) = entry.weight;
+          output(row, 2) = entry.min_rank;
+          output(row, 3) = entry.max_rank;
+        }
+      }
+    };
+    // TODO(tanzheny): comment on the magic number.
+    const int64 kCostPerUnit = 500 * batch_size;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+          kCostPerUnit, do_quantile_summary_gen);
+  }
+
+ private:
+  int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+    BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+  explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    ResourceHandle handle;
+    OP_REQUIRES_OK(context,
+                   HandleFromInput(context, kResourceHandleName, &handle));
+    QuantileStreamResource* stream_resource;
+    // Create a reference to the underlying resource using the handle.
+    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+    // Remove the reference at the end of this scope.
+    mutex_lock l(*stream_resource->mutex());
+    core::ScopedUnref unref_me(stream_resource);
+
+    OpInputList summaries_list;
+    OP_REQUIRES_OK(context,
+                   context->input_list(kSummariesName, &summaries_list));
+    int32 num_streams = stream_resource->num_streams();
+    CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+    auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+      // Iterating over all features.
+      for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+        const Tensor& summaries = summaries_list[feature_idx];
+        const auto summary_values = summaries.matrix<float>();
+        const auto& tensor_shape = summaries.shape();
+        const int64 entries_size = tensor_shape.dim_size(0);
+        CHECK_EQ(tensor_shape.dim_size(1), 4);
+        std::vector<QuantileSummaryEntry> summary_entries;
+        summary_entries.reserve(entries_size);
+        for (int64 i = 0; i < entries_size; i++) {
+          float value = summary_values(i, 0);
+          float weight = summary_values(i, 1);
+          float min_rank = summary_values(i, 2);
+          float max_rank = summary_values(i, 3);
+          QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+          summary_entries.push_back(entry);
+        }
+        stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+      }
+    };
+
+    // TODO(tanzheny): comment on the magic number.
+    const int64 kCostPerUnit = 500 * num_streams;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+          kCostPerUnit, do_quantile_add_summary);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+    BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+  explicit BoostedTreesQuantileStreamResourceFlushOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    ResourceHandle handle;
+    OP_REQUIRES_OK(context,
+                   HandleFromInput(context, kResourceHandleName, &handle));
+    QuantileStreamResource* stream_resource;
+    // Create a reference to the underlying resource using the handle.
+    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+    // Remove the reference at the end of this scope.
+    mutex_lock l(*stream_resource->mutex());
+    core::ScopedUnref unref_me(stream_resource);
+
+    const Tensor* num_buckets_t;
+    OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+    const int64 num_buckets = num_buckets_t->scalar<int64>()();
+    const int64 num_streams = stream_resource->num_streams();
+
+    auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+      // Iterating over all streams.
+      for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+        QuantileStream* stream = stream_resource->stream(stream_idx);
+        stream->Finalize();
+        stream_resource->set_boundaries(
+            generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+                                : GenerateBoundaries(*stream, num_buckets),
+            stream_idx);
+      }
+    };
+
+    // TODO(tanzheny): comment on the magic number.
+    const int64 kCostPerUnit = 500 * num_streams;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+          kCostPerUnit, do_quantile_flush);
+
+    stream_resource->set_buckets_ready(true);
+  }
+
+ private:
+  bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+    BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+    : public OpKernel {
+ public:
+  explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    ResourceHandle handle;
+    OP_REQUIRES_OK(context,
+                   HandleFromInput(context, kResourceHandleName, &handle));
+    QuantileStreamResource* stream_resource;
+    // Create a reference to the underlying resource using the handle.
+    OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+    // Remove the reference at the end of this scope.
+    mutex_lock l(*stream_resource->mutex());
+    core::ScopedUnref unref_me(stream_resource);
+
+    const int64 num_streams = stream_resource->num_streams();
+    CHECK_EQ(num_features_, num_streams);
+    OpOutputList bucket_boundaries_list;
+    OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+                                                 &bucket_boundaries_list));
+
+    auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+      // Iterating over all streams.
+      for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+        const auto& boundaries = stream_resource->boundaries(stream_idx);
+        Tensor* bucket_boundaries_t = nullptr;
+        OP_REQUIRES_OK(context,
+                       bucket_boundaries_list.allocate(
+                           stream_idx, {static_cast<int64>(boundaries.size())},
+                           &bucket_boundaries_t));
+        auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+        memcpy(quantiles_flat, boundaries.data(),
+               sizeof(float) * boundaries.size());
+      }
+    };
+
+    // TODO(tanzheny): comment on the magic number.
+    const int64 kCostPerUnit = 500 * num_streams;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+          kCostPerUnit, do_quantile_get_buckets);
+  }
+
+ private:
+  int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+        .Device(DEVICE_CPU),
+    BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantiles thresholds and input data, this operation
+// converts the input features into the buckets (categorical values), depending
+// on which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+  explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // Read the list of float feature tensors.
+    OpInputList float_features_list;
+    OP_REQUIRES_OK(
+        context, context->input_list(kFloatFeaturesName, &float_features_list));
+    OpInputList bucket_boundaries_list;
+    OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+                                                &bucket_boundaries_list));
+    OP_REQUIRES(context,
+                tensorflow::TensorShapeUtils::IsVector(
+                    bucket_boundaries_list[0].shape()),
+                errors::InvalidArgument(
+                    strings::Printf("Buckets should be flat vectors.")));
+    OpOutputList buckets_list;
+    OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+    auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+      // Iterating over all features.
+      for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+        const Tensor& values_tensor = float_features_list[feature_idx];
+        const int64 num_values = values_tensor.dim_size(0);
+
+        Tensor* output_t = nullptr;
+        OP_REQUIRES_OK(
+            context, buckets_list.allocate(
+                         feature_idx, TensorShape({num_values, 1}), &output_t));
+        auto output = output_t->matrix<int32>();
+
+        const std::vector<float>& bucket_boundaries_vector =
+            GetBuckets(feature_idx, bucket_boundaries_list);
+        CHECK(!bucket_boundaries_vector.empty())
+            << "Got empty buckets for feature " << feature_idx;
+        auto flat_values = values_tensor.flat<float>();
+        for (int64 instance = 0; instance < num_values; instance++) {
+          const float value = flat_values(instance);
+          auto bucket_iter =
+              std::lower_bound(bucket_boundaries_vector.begin(),
+                               bucket_boundaries_vector.end(), value);
+          if (bucket_iter == bucket_boundaries_vector.end()) {
+            --bucket_iter;
+          }
+          const int32 bucket = static_cast<int32>(
+              bucket_iter - bucket_boundaries_vector.begin());
+          // Bucket id.
+          output(instance, 0) = bucket;
+        }
+      }
+    };
+
+    // TODO(tanzheny): comment on the magic number.
+    const int64 kCostPerUnit = 500 * num_features_;
+    const DeviceBase::CpuWorkerThreads& worker_threads =
+        *context->device()->tensorflow_cpu_worker_threads();
+    Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+          kCostPerUnit, do_quantile_get_quantiles);
+  }
+
+ private:
+  int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+                        BoostedTreesBucketizeOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63..12d9473 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
 # Description:
-#   This directory contains common utilities used in boosted_trees.
+#   This directory contains common quantile utilities used in boosted_trees.
 package(
     default_visibility = ["//tensorflow:internal"],
 )
@@ -16,6 +16,7 @@
     name = "weighted_quantiles",
     srcs = [],
     hdrs = [
+        "quantile_stream_resource.h",
         "weighted_quantiles_buffer.h",
         "weighted_quantiles_stream.h",
         "weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/core:framework_headers_lib",
+        "//third_party/eigen3",
     ],
 )
 
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000..1c31724
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+  BoostedTreesQuantileStreamResource(const float epsilon,
+                                     const int64 max_elements,
+                                     const int64 num_streams)
+      : are_buckets_ready_(false),
+        epsilon_(epsilon),
+        num_streams_(num_streams),
+        max_elements_(max_elements) {
+          streams_.reserve(num_streams_);
+          boundaries_.reserve(num_streams_);
+          for (int64 idx = 0; idx < num_streams; ++idx) {
+            streams_.push_back(QuantileStream(epsilon, max_elements));
+            boundaries_.push_back(std::vector<float>());
+          }
+        }
+
+  string DebugString() override { return "QuantileStreamResource"; }
+
+  tensorflow::mutex* mutex() { return &mu_; }
+
+  QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+  const std::vector<float>& boundaries(const int64 index) {
+    return boundaries_[index];
+  }
+
+  void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+    boundaries_[index] = boundaries;
+  }
+
+  float epsilon() const { return epsilon_; }
+  int64 num_streams() const { return num_streams_; }
+
+  bool are_buckets_ready() const { return are_buckets_ready_; }
+  void set_buckets_ready(const bool are_buckets_ready) {
+    are_buckets_ready_ = are_buckets_ready;
+  }
+
+ private:
+  ~BoostedTreesQuantileStreamResource() override {}
+
+  // Mutex for the whole resource.
+  tensorflow::mutex mu_;
+
+  // Quantile streams.
+  std::vector<QuantileStream> streams_;
+
+  // Stores the boundaries. Same size as streams_.
+  std::vector<std::vector<float>> boundaries_;
+
+  // Whether boundaries are created. Initially boundaries are empty until
+  // set_boundaries() is called.
+  bool are_buckets_ready_;
+
+  const float epsilon_;
+  const int64 num_streams_;
+  // An upper-bound for the number of elements.
+  int64 max_elements_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a783689..390db8f 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@
   //   dtype: The datatype of the gradients to be accumulated.
   //   shape: The shape of the accumulated gradients.
   //   name:  A name to use for the ConditionalAccumulator.
+  //   reduction_type: The reduction type, i.e., MEAN or SUM
   ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
-                         const string& name)
-      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+                         const string& name, const string& reduction_type)
+      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+                                                      reduction_type) {}
   ~ConditionalAccumulator() override{};
 
  protected:
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c5..292cf0c 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
 ConditionalAccumulatorBase::ConditionalAccumulatorBase(
-    const DataType& dtype, const PartialTensorShape& shape, const string& name)
-    : dtype_(dtype), shape_(shape), name_(name) {
+    const DataType& dtype, const PartialTensorShape& shape, const string& name,
+    const string& reduction_type)
+    : dtype_(dtype),
+      shape_(shape),
+      name_(name),
+      reduction_type_(reduction_type) {
   counter_ = 0;
   current_global_step_ = 0;
 }
@@ -190,7 +195,9 @@
   current_global_step_++;
 
   // Average the accumulated gradient
-  DivideAccumGradByCounter(ctx);
+  if (reduction_type_ == "MEAN") {
+    DivideAccumGradByCounter(ctx);
+  }
 
   // Set output for accumulated gradient tensor
   bool successful_set_output = SetOutput(ctx);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index b7b7482..4a5ec6f 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -52,7 +52,7 @@
   //   name:  A name to use for the ConditionalAccumulator.
   ConditionalAccumulatorBase(const DataType& dtype,
                              const PartialTensorShape& shape,
-                             const string& name);
+                             const string& name, const string& reduction_type);
 
   typedef AsyncOpKernel::DoneCallback DoneCallback;
 
@@ -125,6 +125,7 @@
   const DataType dtype_;
   const PartialTensorShape shape_;
   const string name_;
+  const string reduction_type_;
   mutex mu_;
   int counter_ GUARDED_BY(mu_);
   int64 current_global_step_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 012a0dc..ca24d69 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -51,6 +51,8 @@
                                                 &accumulator_handle_, nullptr));
     OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
     OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("reduction_type", &reduction_type_));
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@
   DataType dtype_;
   PartialTensorShape shape_;
   ContainerInfo cinfo_;
+  string reduction_type_;
 
  private:
   Status SetAccumulatorHandle(OpKernelContext* ctx)
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a..52ac51a 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@
   Creator GetCreator() const override {
     return [this](ConditionalAccumulatorBase** ret) {
       ConditionalAccumulator<Device, T>* accumulator =
-          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+          new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+                                                reduction_type_);
       *ret = accumulator;
       return Status::OK();
     };
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 02e3655..b819c6f 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -19,6 +19,7 @@
 #define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
 
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
 #include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
 
 namespace tensorflow {
@@ -28,6 +29,14 @@
 template <typename Device, typename T>
 struct CuboidConvolution;
 
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 template <typename T>
@@ -42,6 +51,40 @@
   }
 };
 
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+  void operator()(const CPUDevice& d,
+                  typename TTypes<T, 5>::Tensor input_backward,
+                  typename TTypes<T, 5>::ConstTensor filter,
+                  typename TTypes<T, 5>::ConstTensor output_backward,
+                  int stride_planes, int stride_rows, int stride_cols) {
+    // Need to swap the order of plane/row/col strides when calling Eigen.
+    input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+        filter, output_backward,
+        input_backward.dimension(3),  // input_planes
+        input_backward.dimension(2),  // input_rows
+        input_backward.dimension(1),  // input_cols
+        stride_cols, stride_rows, stride_planes);
+  }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+  void operator()(const CPUDevice& d,
+                  typename TTypes<T, 5>::Tensor filter_backward,
+                  typename TTypes<T, 5>::ConstTensor input,
+                  typename TTypes<T, 5>::ConstTensor output_backward,
+                  int stride_planes, int stride_rows, int stride_cols) {
+    // Need to swap the order of plane/row/col strides when calling Eigen.
+    filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+        input, output_backward,
+        filter_backward.dimension(2),  // kernel_planes
+        filter_backward.dimension(1),  // kernel_rows
+        filter_backward.dimension(0),  // kernel_cols
+        stride_cols, stride_rows, stride_planes);
+  }
+};
+
 }  // namespace functor
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fc0a2f1..507720c 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@
 
 namespace tensorflow {
 
+// Compute padding for the given spatial dimension.
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+                                           int dim) const {
+  return (padding == VALID)
+             ? 0
+             : std::max<int>(
+                   0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+                                       (filter_size(dim) - 1) * dilation(dim) +
+                                       1 - input_size(dim)));
+}
+
 // The V2 version computes windowed output size with arbitrary dilation_rate,
 // while the original version only handles the cases where dilation_rates equal
 // to 1.
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d..9551959 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@
 
   // Input and output feature depth.
   int64 in_depth, out_depth;
+
+  // Convenience access methods for spatial dimensions properties.
+  int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+  int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+  int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+  int64 stride(int dim) const { return spatial_dims[dim].stride; }
+  int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+  // Compute padding for the given spatial dimension.
+  int SpatialPadding(const Padding& padding, int dim) const;
 };
 
 // Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9..d26b86c 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
 #include "tensorflow/core/kernels/conv_ops_gpu.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 #if GOOGLE_CUDA
 #include "tensorflow/core/platform/stream_executor.h"
 using stream_executor::dnn::DimIndex;
 #endif
 
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
+// order (planes, height, width, depth), constructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+            const int height, const int width, const int filter_p,
+            const int filter_h, const int filter_w, const int pad_pt,
+            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+            const int pad_r, const int stride_p, const int stride_h,
+            const int stride_w, T* im_data) {
+  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+  int p_pad = -pad_pt;
+  for (int p = 0; p < planes_col; ++p) {
+    int h_pad = -pad_t;
+    for (int h = 0; h < height_col; ++h) {
+      int w_pad = -pad_l;
+      for (int w = 0; w < width_col; ++w) {
+        T* im_patch_data =
+            im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+                  iw < width) {
+                for (int i = 0; i < depth; ++i) {
+                  im_patch_data[i] += col_data[i];
+                }
+              }
+              im_patch_data += depth;
+              col_data += depth;
+            }
+            // Jump over remaining number of depth.
+            im_patch_data += depth * (width - filter_w);
+          }
+          // Jump over remaining number of (depth * width).
+          im_patch_data += (depth * width) * (height - filter_h);
+        }
+        w_pad += stride_w;
+      }
+      h_pad += stride_h;
+    }
+    p_pad += stride_p;
+  }
+}
+
+// Returns in 'col_data', image patches in storage order (planes, height, width,
+// depth) extracted from image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+            const int height, const int width, const int filter_p,
+            const int filter_h, const int filter_w, const int pad_pt,
+            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+            const int pad_r, const int stride_p, const int stride_h,
+            const int stride_w, T* col_data) {
+  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+  int p_pad = -pad_pt;
+  for (int p = 0; p < planes_col; ++p) {
+    int h_pad = -pad_t;
+    for (int h = 0; h < height_col; ++h) {
+      int w_pad = -pad_l;
+      for (int w = 0; w < width_col; ++w) {
+        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+                  iw < width) {
+                memcpy(col_data,
+                       input_data +
+                           (ip * height * width + ih * width + iw) * depth,
+                       sizeof(T) * depth);
+              } else {
+                // This should be simply padded with zero.
+                memset(col_data, 0, sizeof(T) * depth);
+              }
+              col_data += depth;
+            }
+          }
+        }
+        w_pad += stride_w;
+      }
+      h_pad += stride_h;
+    }
+    p_pad += stride_p;
+  }
+}
+
+}  // namespace
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label)                                   \
-  const Tensor& out_backprop = context->input(2);                              \
-  OP_REQUIRES(                                                                 \
-      context, input_shape.dims() == 5,                                        \
-      errors::InvalidArgument(label, ": input must be 5-dimensional"));        \
-  OP_REQUIRES(                                                                 \
-      context, filter_shape.dims() == 5,                                       \
-      errors::InvalidArgument(label, ": filter must be 5-dimensional"));       \
-  OP_REQUIRES(                                                                 \
-      context, out_backprop.dims() == 5,                                       \
-      errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
-  const int64 batch = input_shape.dim_size(0);                                 \
-  OP_REQUIRES(                                                                 \
-      context, batch == out_backprop.dim_size(0),                              \
-      errors::InvalidArgument(                                                 \
-          label, ": input and out_backprop must have the same batch size"));   \
-  const std::array<int64, 3> input_size = {                                    \
-      {GetTensorDim(input_shape, data_format_, '0'),                           \
-       GetTensorDim(input_shape, data_format_, '1'),                           \
-       GetTensorDim(input_shape, data_format_, '2')}};                         \
-  const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C');         \
-  const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0),         \
-                                             filter_shape.dim_size(1),         \
-                                             filter_shape.dim_size(2)}};       \
-  const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2');     \
-  const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1');     \
-  const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0');   \
-  OP_REQUIRES(context, in_depth == filter_shape.dim_size(3),                   \
-              errors::InvalidArgument(                                         \
-                  label, ": input and filter must have the same depth"));      \
-  const int64 out_depth = filter_shape.dim_size(4);                            \
-  OP_REQUIRES(                                                                 \
-      context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'),     \
-      errors::InvalidArgument(                                                 \
-          label, ": filter and out_backprop must have the same out_depth"));   \
-  const std::array<int64, 3> dilations = {                                     \
-      {GetTensorDim(dilation_, data_format_, '0'),                             \
-       GetTensorDim(dilation_, data_format_, '1'),                             \
-       GetTensorDim(dilation_, data_format_, '2')}};                           \
-  const std::array<int64, 3> strides = {                                       \
-      {GetTensorDim(stride_, data_format_, '0'),                               \
-       GetTensorDim(stride_, data_format_, '1'),                               \
-       GetTensorDim(stride_, data_format_, '2')}};                             \
-  std::array<int64, 3> out, padding;                                           \
-  OP_REQUIRES_OK(                                                              \
-      context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides,  \
-                                 padding_, &out, &padding));                   \
-  OP_REQUIRES(context, output_planes == out[0],                                \
-              errors::InvalidArgument(                                         \
-                  label,                                                       \
-                  ": Number of planes of out_backprop doesn't match "          \
-                  "computed:  actual = ",                                      \
-                  output_planes, ", computed = ", out[0]));                    \
-  OP_REQUIRES(                                                                 \
-      context, output_rows == out[1],                                          \
-      errors::InvalidArgument(                                                 \
-          label, ": Number of rows of out_backprop doesn't match computed: ",  \
-          "actual = ", output_rows, ", computed = ", out[1]));                 \
-  OP_REQUIRES(                                                                 \
-      context, output_cols == out[2],                                          \
-      errors::InvalidArgument(                                                 \
-          label, ": Number of cols of out_backprop doesn't match computed: ",  \
-          "actual = ", output_cols, ", computed = ", out[2]));                 \
-  const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1;       \
-  const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1;           \
-  const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1;           \
-  const auto padded_out_planes = input_size[0] + filter_size[0] - 1;           \
-  const auto padded_out_rows = input_size[1] + filter_size[1] - 1;             \
-  const auto padded_out_cols = input_size[2] + filter_size[2] - 1;             \
-  const auto top_pad_planes = filter_size[0] - 1 - padding[0];                 \
-  const auto top_pad_rows = filter_size[1] - 1 - padding[1];                   \
-  const auto left_pad_cols = filter_size[2] - 1 - padding[2];                  \
-  const auto bottom_pad_planes =                                               \
-      padded_out_planes - expanded_out_planes - top_pad_planes;                \
-  const auto bottom_pad_rows =                                                 \
-      padded_out_rows - expanded_out_rows - top_pad_rows;                      \
-  const auto right_pad_cols =                                                  \
-      padded_out_cols - expanded_out_cols - left_pad_cols;                     \
-  VLOG(2) << "Conv3d: " << label                                               \
-          << ": expanded_out_planes = " << expanded_out_planes                 \
-          << ": expanded_out_rows = " << expanded_out_rows                     \
-          << ", expanded_out_cols = " << expanded_out_cols                     \
-          << ", padded_out_planes = " << padded_out_planes                     \
-          << ", padded_out_rows = " << padded_out_rows                         \
-          << ", padded_out_cols = " << padded_out_cols                         \
-          << ", top_pad_planes = " << top_pad_planes                           \
-          << ", top_pad_rows = " << top_pad_rows                               \
-          << ", left_pad_cols = " << left_pad_cols                             \
-          << ", bottom_pad_planes = " << bottom_pad_planes                     \
-          << ", bottom_pad_rows = " << bottom_pad_rows                         \
-          << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
 template <typename Device, class T>
 class Conv3DBackpropInputOp : public OpKernel {
  public:
@@ -192,6 +212,10 @@
   void Compute(OpKernelContext* context) override {
     const Tensor& filter = context->input(1);
     const TensorShape& filter_shape = filter.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
     TensorShape input_shape;
     if (takes_shape_) {
       const Tensor& input_sizes = context->input(0);
@@ -200,51 +224,25 @@
     } else {
       input_shape = context->input(0).shape();
     }
-    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
-        {0, 0},
-        {top_pad_planes, bottom_pad_planes},
-        {top_pad_rows, bottom_pad_rows},
-        {left_pad_cols, right_pad_cols},
-        {0, 0}};
+
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+                                input_shape, filter_shape, out_backprop_shape,
+                                stride_, padding_, data_format_, &dims));
+
     Tensor* in_backprop;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input_shape, &in_backprop));
 
-    // Fill out a padded out_backprop.
-    TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
-                                  padded_out_cols, out_depth});
-    Tensor padded_output;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::v(),
-                                          padded_out_shape, &padded_output));
-    Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
-    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
-                                                      strides[2], 1};
-    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
-        eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
-    const Tensor& padded_output_cref = padded_output;
-
-    // Fill a new "reverted" filter. We need to transpose the in_depth and
-    // out_depth for the filter and reverse the planes, rows and cols.
-    TensorShape r_filter_shape(
-        {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
-    Tensor r_filter;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
-                                                   r_filter_shape, &r_filter));
-    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
-    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
-    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
-        filter_rev_dims, r_filter.tensor<T, 5>());
-    const Tensor& r_filter_cref = r_filter;
-
-    // Now we can call conv_3d directly.
-    functor::CuboidConvolution<Device, T>()(
-        context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
-        padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
-        1, BrainPadding2EigenPadding(VALID));
+    functor::CuboidConvolutionBackwardInput<Device, T>()(
+        context->eigen_device<Device>(),
+        in_backprop->tensor<T, 5>(),                     // input_backward
+        filter.tensor<T, 5>(),                           // filter
+        out_backprop.tensor<T, 5>(),                     // output_backward
+        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
+        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
+        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
   }
 
  private:
@@ -253,21 +251,368 @@
   Padding padding_;
   TensorFormat data_format_;
   bool takes_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
 };
 
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+//
+// For each image in the batch it computes the column-form gradient as a
+// single matmul (out_backprop x filter^T) and then scatters it back into the
+// input gradient with Col2im. Batches are sharded across the CPU worker
+// threads; if the temporary column buffer would be disproportionately large,
+// the op falls back on the Eigen CuboidConvolutionBackwardInput kernel.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+  // Limit the maximum size of allocated temporary buffer to
+  // kMaxTempAllocationOverhead times the size of the input tensors (input,
+  // filter, out_backprop). If the size of the temporary buffer exceeds this
+  // limit, fallback on Eigen implementation.
+  static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+  // Reads and validates the data_format/dilations/strides/padding attrs.
+  // Only NDHWC data format, dilation rate 1 in every dimension, and stride 1
+  // in the batch and depth dimensions are supported.
+  explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+      : OpKernel(context),
+        data_format_(FORMAT_NHWC),
+        takes_shape_(type_string().find("V2") != std::string::npos) {
+    // data_format is only available in V2.
+    if (takes_shape_) {
+      string data_format;
+      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                  errors::InvalidArgument("Invalid data format"));
+      OP_REQUIRES(
+          context, data_format_ == FORMAT_NHWC,
+          errors::InvalidArgument(
+              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+    }
+
+    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+    OP_REQUIRES(context, dilation_.size() == 5,
+                errors::InvalidArgument("Dilation rates field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+                 GetTensorDim(dilation_, data_format_, 'N') == 1),
+                errors::InvalidArgument(
+                    "Current implementation does not yet support "
+                    "dilation rates in the batch and depth dimensions."));
+
+    // TODO(yangzihao): Add CPU version of dilated conv 3D.
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
+                 GetTensorDim(dilation_, data_format_, '2') == 1),
+                errors::InvalidArgument(
+                    "Current CPU implementation does not yet support "
+                    "dilation rates larger than 1."));
+
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+    OP_REQUIRES(context, stride_.size() == 5,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+         GetTensorDim(stride_, data_format_, 'N') == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  // Computes the gradient of the convolution with respect to the input.
+  // Inputs: 0 = input (or input_sizes for V2), 1 = filter, 2 = out_backprop.
+  // Output 0 is in_backprop with the same shape as the input.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& filter = context->input(1);
+    const TensorShape& filter_shape = filter.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
+    TensorShape input_shape;
+    if (takes_shape_) {
+      const Tensor& input_sizes = context->input(0);
+      // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+    } else {
+      input_shape = context->input(0).shape();
+    }
+
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+                                input_shape, filter_shape, out_backprop_shape,
+                                stride_, padding_, data_format_, &dims));
+
+    Tensor* in_backprop;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, input_shape, &in_backprop));
+
+    // Recover the explicit before/after padding for each spatial dimension
+    // (planes, rows, cols); Col2im below needs them to map gradient columns
+    // back into the input gradient.
+    int64 top_pad_planes, bottom_pad_planes;
+    int64 top_pad_rows, bottom_pad_rows;
+    int64 left_pad_cols, right_pad_cols;
+
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[0].input_size,
+                                dims.spatial_dims[0].filter_size,
+                                dims.spatial_dims[0].stride, padding_,
+                                &dims.spatial_dims[0].output_size,
+                                &top_pad_planes, &bottom_pad_planes));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[1].input_size,
+                                dims.spatial_dims[1].filter_size,
+                                dims.spatial_dims[1].stride, padding_,
+                                &dims.spatial_dims[1].output_size,
+                                &top_pad_rows, &bottom_pad_rows));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[2].input_size,
+                                dims.spatial_dims[2].filter_size,
+                                dims.spatial_dims[2].stride, padding_,
+                                &dims.spatial_dims[2].output_size,
+                                &left_pad_cols, &right_pad_cols));
+
+    // TODO(ezhulenev): Extract work size and shard estimation to shared
+    // functions in conv_grad_ops, and update 2d convolution backprop.
+
+    // The total dimension size of each kernel.
+    const int64 filter_total_size =
+        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+        dims.spatial_dims[2].filter_size * dims.in_depth;
+
+    // The output image size is the spatial size of the output.
+    const int64 output_image_size = dims.spatial_dims[0].output_size *
+                                    dims.spatial_dims[1].output_size *
+                                    dims.spatial_dims[2].output_size;
+
+    const auto cache_sizes = Eigen::internal::CacheSizes();
+    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+    // Use L3 cache size as target working set size.
+    const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+    // Calculate size of matrices involved in MatMul: C = A x B.
+    const int64 size_A = output_image_size * dims.out_depth;
+
+    const int64 size_B = filter_total_size * dims.out_depth;
+
+    const int64 size_C = output_image_size * filter_total_size;
+
+    const int64 work_unit_size = size_A + size_B + size_C;
+
+    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+    // Use parallel tensor contractions if there is no batching.
+    //
+    // Compared to the Conv2D code, this version is missing the work size
+    // estimation. In benchmarks no case was found where running a parallel
+    // contraction beats sharding plus matmuls.
+    const bool use_parallel_contraction = dims.batch_size == 1;
+
+    const size_t shard_size =
+        use_parallel_contraction
+            ? 1
+            : (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+    // Total number of elements in all the tensors used by this kernel.
+    int64 total_tensor_elements = input_shape.num_elements() +
+                                  filter_shape.num_elements() +
+                                  out_backprop_shape.num_elements();
+
+    // Shape of the temporary workspace buffer.
+    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+                                    static_cast<int64>(output_image_size),
+                                    static_cast<int64>(filter_total_size)};
+    int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+    // If the temporary allocation overhead is too large, fallback on Eigen
+    // implementation which requires much less memory.
+    // NOTE(review): if input, filter and out_backprop are all empty this
+    // divides by zero — presumably shape validation upstream rules that out;
+    // confirm.
+    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+                 "col_buffer_overhead="
+              << col_buffer_overhead;
+
+      functor::CuboidConvolutionBackwardInput<Device, T>()(
+          context->eigen_device<Device>(),
+          in_backprop->tensor<T, 5>(),                     // input_backward
+          filter.tensor<T, 5>(),                           // filter
+          out_backprop.tensor<T, 5>(),                     // output_backward
+          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
+          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
+          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
+
+      return;
+    }
+
+    Tensor col_buffer;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::value,
+                                          col_buffer_shape, &col_buffer));
+
+    // The input offset corresponding to a single input image.
+    const int64 input_offset = dims.spatial_dims[0].input_size *
+                               dims.spatial_dims[1].input_size *
+                               dims.spatial_dims[2].input_size * dims.in_depth;
+
+    // The output offset corresponding to a single output image.
+    const int64 output_offset =
+        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+        dims.spatial_dims[2].output_size * dims.out_depth;
+
+    const T* filter_data = filter.template flat<T>().data();
+    T* col_buffer_data = col_buffer.template flat<T>().data();
+    const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+    // Zero-initialize the whole gradient up front: Col2im accumulates
+    // (overlapping windows) rather than overwrites.
+    auto in_backprop_flat = in_backprop->template flat<T>();
+    T* input_backprop_data = in_backprop_flat.data();
+    in_backprop_flat.device(context->eigen_device<Device>()) =
+        in_backprop_flat.constant(T(0));
+
+    if (use_parallel_contraction) {
+      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+                               Eigen::Unaligned>
+          TensorMap;
+      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+                               Eigen::Unaligned>
+          ConstTensorMap;
+
+      // Initialize contraction dims (we need to transpose 'B' below).
+      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+      contract_dims[0].first = 1;
+      contract_dims[0].second = 1;
+
+      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+        // Compute gradient into col_buffer.
+        TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+        ConstTensorMap A(out_backprop_data + output_offset * image_id,
+                         output_image_size, dims.out_depth);
+        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+        Col2im<T>(col_buffer_data, dims.in_depth,
+                  // Input spatial dimensions.
+                  dims.spatial_dims[0].input_size,  // input planes
+                  dims.spatial_dims[1].input_size,  // input rows
+                  dims.spatial_dims[2].input_size,  // input cols
+                  // Filter spatial dimensions.
+                  dims.spatial_dims[0].filter_size,  // filter planes
+                  dims.spatial_dims[1].filter_size,  // filter rows
+                  dims.spatial_dims[2].filter_size,  // filter cols
+                  // Spatial padding.
+                  top_pad_planes, top_pad_rows, left_pad_cols,
+                  bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+                  // Spatial striding.
+                  dims.spatial_dims[0].stride,  // stride planes
+                  dims.spatial_dims[1].stride,  // stride rows
+                  dims.spatial_dims[2].stride,  // stride cols
+                  input_backprop_data);
+
+        input_backprop_data += input_offset;
+      }
+    } else {
+      typedef Eigen::Map<
+          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+          MatrixMap;
+      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+                                             Eigen::RowMajor>>
+          ConstMatrixMap;
+
+      // Process the batch shard_size images at a time; each shard's images
+      // are distributed across worker threads by Shard() below.
+      for (int image_id = 0; image_id < dims.batch_size;
+           image_id += shard_size) {
+        const int shard_limit =
+            std::min(static_cast<int>(shard_size),
+                     static_cast<int>(dims.batch_size) - image_id);
+
+        auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+                      &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+                      &output_image_size, &filter_total_size,
+                      &input_backprop_data, &col_buffer_data,
+                      &out_backprop_data, &filter_data, &input_offset,
+                      &output_offset, &size_C](int64 start, int64 limit) {
+          for (int shard_id = start; shard_id < limit; ++shard_id) {
+            // Each shard_id owns a disjoint slice of the column buffer and
+            // of the input/output gradients, so no synchronization is needed.
+            T* im2col_buf = col_buffer_data + shard_id * size_C;
+            T* input_data = input_backprop_data + shard_id * input_offset;
+            const T* out_data = out_backprop_data + shard_id * output_offset;
+
+            // Compute gradient into 'im2col_buf'.
+            MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+            ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+            ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+            C.noalias() = A * B.transpose();
+
+            Col2im<T>(im2col_buf, dims.in_depth,
+                      // Input spatial dimensions.
+                      dims.spatial_dims[0].input_size,  // input planes
+                      dims.spatial_dims[1].input_size,  // input rows
+                      dims.spatial_dims[2].input_size,  // input cols
+                      // Filter spatial dimensions.
+                      dims.spatial_dims[0].filter_size,  // filter planes
+                      dims.spatial_dims[1].filter_size,  // filter rows
+                      dims.spatial_dims[2].filter_size,  // filter cols
+                      // Spatial padding.
+                      top_pad_planes, top_pad_rows, left_pad_cols,
+                      bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+                      // Spatial striding.
+                      dims.spatial_dims[0].stride,  // stride planes
+                      dims.spatial_dims[1].stride,  // stride rows
+                      dims.spatial_dims[2].stride,  // stride cols
+                      input_data);
+          }
+        };
+        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+              work_unit_size, shard);
+
+        input_backprop_data += input_offset * shard_limit;
+        out_backprop_data += output_offset * shard_limit;
+      }
+    }
+  }
+
+ private:
+  // Attrs captured at construction; dilation_ and stride_ are 5-element
+  // vectors indexed through GetTensorDim using data_format_.
+  std::vector<int32> dilation_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+  // True for the V2 op, which takes input_sizes as input 0 instead of the
+  // input tensor itself.
+  bool takes_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
+};
+
+// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 than
+// the default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
 #define REGISTER_CPU_KERNEL(T)                                                 \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
-      Conv3DBackpropInputOp<CPUDevice, T>);                                    \
+      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      Conv3DBackpropInputOp<CPUDevice, T>);
+      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
+                              .Device(DEVICE_CPU)                              \
+                              .Label("custom")                                 \
+                              .TypeConstraint<T>("T"),                         \
+                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
+                              .Device(DEVICE_CPU)                              \
+                              .Label("custom")                                 \
+                              .TypeConstraint<T>("T"),                         \
+                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
+                              .Device(DEVICE_CPU)                              \
+                              .Label("eigen_tensor")                           \
+                              .TypeConstraint<T>("T"),                         \
+                          Conv3DBackpropInputOp<CPUDevice, T>);                \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
+                              .Device(DEVICE_CPU)                              \
+                              .Label("eigen_tensor")                           \
+                              .TypeConstraint<T>("T"),                         \
+                          Conv3DBackpropInputOp<CPUDevice, T>);
+
 TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 TF_CALL_double(REGISTER_CPU_KERNEL);
 #undef REGISTER_CPU_KERNEL
 
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
 template <typename Device, class T>
 class Conv3DBackpropFilterOp : public OpKernel {
  public:
@@ -323,8 +668,11 @@
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
-    TensorShape filter_shape;
 
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
+    TensorShape filter_shape;
     if (takes_shape_) {
       const Tensor& filter_sizes = context->input(1);
       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@
       filter_shape = context->input(1).shape();
     }
 
-    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
-        {0, 0},
-        {top_pad_planes, bottom_pad_planes},
-        {top_pad_rows, bottom_pad_rows},
-        {left_pad_cols, right_pad_cols},
-        {0, 0}};
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context,
+                   ConvBackpropComputeDimensions(
+                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+                       input_shape, filter_shape, out_backprop_shape, stride_,
+                       padding_, data_format_, &dims));
+
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,14 @@
       return;
     }
 
-    // For the backprop of the filter, we need to also transpose the
-    // out_backprop.
-    // The shape of backprop is
-    //   [batch, out_z, out_y, out_x, out_depth]
-    // And we need to change it to
-    //   [out_depth, out_x, out_y, out_z, batch]
-    Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
-    TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
-                                  padded_out_cols, batch});
-    Tensor padded_output;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::v(),
-                                          padded_out_shape, &padded_output));
-    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
-                                                      strides[2], 1};
-    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
-        eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
-    const Tensor& padded_output_cref = padded_output;
-
-    // For the backprop of the filter, we need to transpose the input.
-    // The shape of input is
-    //   [batch, in_z, in_y, in_x, in_depth]
-    // And we need to change it to
-    //   [in_z, in_y, in_x, batch, in_depth]
-    Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
-    TensorShape in_shuffle_shape(
-        {input_size[0], input_size[1], input_size[2], batch, in_depth});
-    Tensor in_shuffle;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(DataTypeToEnum<T>::v(),
-                                          in_shuffle_shape, &in_shuffle));
-    // No need for reversing this time.
-    Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
-    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
-        no_reverse, in_shuffle.tensor<T, 5>());
-    const Tensor& in_shuffle_cref = in_shuffle;
-
-    // The output of the conv_3d would be
-    //   [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
-    // and we need to shuffle it back to
-    //   [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
-    // And we need to reverse the filter backprops.
-    // So we need to allocate (sigh) yet another piece of memory to hold the
-    // output.
-    TensorShape filter_shuffle_shape(
-        {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
-    Tensor filter_shuffle;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataTypeToEnum<T>::v(),
-                                        filter_shuffle_shape, &filter_shuffle));
-    functor::CuboidConvolution<Device, T>()(
-        context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
-        padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
-        1, BrainPadding2EigenPadding(VALID));
-
-    // Now copy the filter_backprop back to the destination.
-    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
-    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
-    const Tensor& filter_shuffle_cref = filter_shuffle;
-    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
-        context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
-        filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+    functor::CuboidConvolutionBackwardFilter<Device, T>()(
+        context->eigen_device<Device>(),
+        filter_backprop->tensor<T, 5>(),                 // filter_backward
+        input.tensor<T, 5>(),                            // input
+        out_backprop.tensor<T, 5>(),                     // output_backward
+        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
+        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
+        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
   }
 
  private:
@@ -421,8 +713,326 @@
   Padding padding_;
   TensorFormat data_format_;
   bool takes_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
 };
 
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+  // Limit the maximum size of allocated temporary buffer to
+  // kMaxTempAllocationOverhead times the size of the input tensors (input,
+  // filter, out_backprop). If the size of the temporary buffer exceeds this
+  // limit, fallback on Eigen implementation.
+  static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+  explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+      : OpKernel(context),
+        data_format_(FORMAT_NHWC),
+        takes_shape_(type_string().find("V2") != std::string::npos) {
+    // data_format is only available in V2.
+    if (takes_shape_) {
+      string data_format;
+      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+                  errors::InvalidArgument("Invalid data format"));
+      OP_REQUIRES(
+          context, data_format_ == FORMAT_NHWC,
+          errors::InvalidArgument(
+              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+    }
+
+    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+    OP_REQUIRES(context, dilation_.size() == 5,
+                errors::InvalidArgument("Dilation rates field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+                 GetTensorDim(dilation_, data_format_, 'N') == 1),
+                errors::InvalidArgument(
+                    "Current implementation does not yet support "
+                    "dilation rates in the batch and depth dimensions."));
+
+    // TODO(yangzihao): Add CPU version of dilated conv 3D.
+    OP_REQUIRES(context,
+                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
+                 GetTensorDim(dilation_, data_format_, '2') == 1),
+                errors::InvalidArgument(
+                    "Current CPU implementation does not yet support "
+                    "dilation rates larger than 1."));
+
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+    OP_REQUIRES(context, stride_.size() == 5,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 5 dimensions"));
+    OP_REQUIRES(
+        context,
+        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+         GetTensorDim(stride_, data_format_, 'N') == 1),
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = context->input(0);
+    const TensorShape& input_shape = input.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
+    TensorShape filter_shape;
+    if (takes_shape_) {
+      const Tensor& filter_sizes = context->input(1);
+      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+                                  filter_sizes.vec<int32>(), &filter_shape));
+    } else {
+      filter_shape = context->input(1).shape();
+    }
+
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context,
+                   ConvBackpropComputeDimensions(
+                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+                       input_shape, filter_shape, out_backprop_shape, stride_,
+                       padding_, data_format_, &dims));
+
+    Tensor* filter_backprop;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, filter_shape, &filter_backprop));
+
+    if (input_shape.num_elements() == 0) {
+      filter_backprop->template flat<T>().setZero();
+      return;
+    }
+
+    int64 top_pad_planes, bottom_pad_planes;
+    int64 top_pad_rows, bottom_pad_rows;
+    int64 left_pad_cols, right_pad_cols;
+
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[0].input_size,
+                                dims.spatial_dims[0].filter_size,
+                                dims.spatial_dims[0].stride, padding_,
+                                &dims.spatial_dims[0].output_size,
+                                &top_pad_planes, &bottom_pad_planes));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[1].input_size,
+                                dims.spatial_dims[1].filter_size,
+                                dims.spatial_dims[1].stride, padding_,
+                                &dims.spatial_dims[1].output_size,
+                                &top_pad_rows, &bottom_pad_rows));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                dims.spatial_dims[2].input_size,
+                                dims.spatial_dims[2].filter_size,
+                                dims.spatial_dims[2].stride, padding_,
+                                &dims.spatial_dims[2].output_size,
+                                &left_pad_cols, &right_pad_cols));
+
+    // TODO(ezhulenev): Extract work size and shard estimation to shared
+    // functions in conv_grad_ops, and update 2d convolution backprop.
+
+    // The total dimension size of each kernel.
+    const int64 filter_total_size =
+        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+        dims.spatial_dims[2].filter_size * dims.in_depth;
+    // The output image size is the spatial size of the output.
+    const int64 output_image_size = dims.spatial_dims[0].output_size *
+                                    dims.spatial_dims[1].output_size *
+                                    dims.spatial_dims[2].output_size;
+
+    // Shard 'batch' images (volumes) into 'shard_size' groups of images
+    // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+    // dividing the L3 cache size ('target_working_set_size') by the matmul size
+    // of an individual image ('work_unit_size').
+
+    const auto cache_sizes = Eigen::internal::CacheSizes();
+    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+    // TODO(andydavis)
+    // *) Consider reducing 'target_working_set_size' if L3 is shared by
+    //    other concurrently running tensorflow ops.
+    const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+    const int64 size_A = output_image_size * filter_total_size;
+
+    const int64 size_B = output_image_size * dims.out_depth;
+
+    const int64 size_C = filter_total_size * dims.out_depth;
+
+    const int64 work_unit_size = size_A + size_B + size_C;
+
+    const size_t shard_size =
+        (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+    // Total number of elements in all the tensors used by this kernel.
+    int64 total_tensor_elements = input_shape.num_elements() +
+                                  filter_shape.num_elements() +
+                                  out_backprop_shape.num_elements();
+
+    // Shape of the temporary workspace buffer.
+    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+                                    static_cast<int64>(output_image_size),
+                                    static_cast<int64>(filter_total_size)};
+    int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+    // If the temporary allocation overhead is too large, fallback on Eigen
+    // implementation which requires much less memory.
+    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+                 "col_buffer_overhead="
+              << col_buffer_overhead;
+
+      functor::CuboidConvolutionBackwardFilter<Device, T>()(
+          context->eigen_device<Device>(),
+          filter_backprop->tensor<T, 5>(),                 // filter_backward
+          input.tensor<T, 5>(),                            // input
+          out_backprop.tensor<T, 5>(),                     // output_backward
+          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
+          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
+          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
+
+      return;
+    }
+
+    Tensor col_buffer;
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(DataTypeToEnum<T>::value,
+                                          col_buffer_shape, &col_buffer));
+
+    // The input offset corresponding to a single input image.
+    const int64 input_offset = dims.spatial_dims[0].input_size *
+                               dims.spatial_dims[1].input_size *
+                               dims.spatial_dims[2].input_size * dims.in_depth;
+    // The output offset corresponding to a single output image.
+    const int64 output_offset =
+        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+        dims.spatial_dims[2].output_size * dims.out_depth;
+
+    const T* input_data = input.template flat<T>().data();
+    T* col_buffer_data = col_buffer.template flat<T>().data();
+    const T* out_backprop_data = out_backprop.template flat<T>().data();
+    T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+                             Eigen::Unaligned>
+        TensorMap;
+    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+                             Eigen::Unaligned>
+        ConstTensorMap;
+
+    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+    C.setZero();
+
+    // Initialize contraction dims (we need to transpose 'A' below).
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+    contract_dims[0].first = 0;
+    contract_dims[0].second = 0;
+
+    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+      const int shard_limit =
+          std::min(static_cast<int>(shard_size),
+                   static_cast<int>(dims.batch_size) - image_id);
+
+      auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+                    &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+                    &bottom_pad_rows, &right_pad_cols, &input_offset,
+                    &size_A](int64 start, int64 limit) {
+        for (int shard_id = start; shard_id < limit; ++shard_id) {
+          const T* input_data_shard = input_data + shard_id * input_offset;
+          T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+          // When we compute the gradient with respect to the filters, we need
+          // to do im2col to allow gemm-type computation.
+          Im2col<T>(input_data_shard, dims.in_depth,
+                    // Input spatial dimensions.
+                    dims.spatial_dims[0].input_size,  // input planes
+                    dims.spatial_dims[1].input_size,  // input rows
+                    dims.spatial_dims[2].input_size,  // input cols
+                    // Filter spatial dimensions.
+                    dims.spatial_dims[0].filter_size,  // filter planes
+                    dims.spatial_dims[1].filter_size,  // filter rows
+                    dims.spatial_dims[2].filter_size,  // filter cols
+                    // Spatial padding.
+                    top_pad_planes, top_pad_rows, left_pad_cols,
+                    bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+                    // Spatial striding.
+                    dims.spatial_dims[0].stride,  // stride planes
+                    dims.spatial_dims[1].stride,  // stride rows
+                    dims.spatial_dims[2].stride,  // stride cols
+                    col_data_shard);
+        }
+      };
+      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+            size_A, shard);
+
+      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+                       filter_total_size);
+      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+                       dims.out_depth);
+
+      // Gradient with respect to filter.
+      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+      input_data += input_offset * shard_limit;
+      out_backprop_data += output_offset * shard_limit;
+    }
+  }
+
+ private:
+  std::vector<int32> dilation_;
+  std::vector<int32> stride_;
+  Padding padding_;
+  TensorFormat data_format_;
+  bool takes_shape_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
+};
+
+// Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2
+// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
+#define REGISTER_CPU_KERNEL(T)                                                \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
+                              .Device(DEVICE_CPU)                             \
+                              .TypeConstraint<T>("T"),                        \
+                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
+                              .Device(DEVICE_CPU)                             \
+                              .Label("custom")                                \
+                              .TypeConstraint<T>("T"),                        \
+                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
+                              .Device(DEVICE_CPU)                             \
+                              .Label("custom")                                \
+                              .TypeConstraint<T>("T"),                        \
+                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
+                              .Device(DEVICE_CPU)                             \
+                              .Label("eigen_tensor")                          \
+                              .TypeConstraint<T>("T"),                        \
+                          Conv3DBackpropFilterOp<CPUDevice, T>);              \
+  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
+                              .Device(DEVICE_CPU)                             \
+                              .Label("eigen_tensor")                          \
+                              .TypeConstraint<T>("T"),                        \
+                          Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_CPU_KERNEL);
+TF_CALL_double(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
+// WARNING: Eigen::half is not trivially copyable and can't be used in
+// custom backprop filter kernel because of memcpy and memset in Im2col.
 #define REGISTER_CPU_KERNEL(T)                                                \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
@@ -431,9 +1041,8 @@
                               .Device(DEVICE_CPU)                             \
                               .TypeConstraint<T>("T"),                        \
                           Conv3DBackpropFilterOp<CPUDevice, T>);
+
 TF_CALL_half(REGISTER_CPU_KERNEL);
-TF_CALL_float(REGISTER_CPU_KERNEL);
-TF_CALL_double(REGISTER_CPU_KERNEL);
 #undef REGISTER_CPU_KERNEL
 
 // GPU definitions of both ops.
@@ -523,6 +1132,10 @@
   void Compute(OpKernelContext* context) override {
     const Tensor& filter = context->input(1);
     const TensorShape& filter_shape = filter.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
     TensorShape input_shape;
     if (takes_shape_) {
       const Tensor& input_sizes = context->input(0);
@@ -531,7 +1144,14 @@
     } else {
       input_shape = context->input(0).shape();
     }
-    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context,
+                   ConvBackpropComputeDimensionsV2(
+                       "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+                       input_shape, filter_shape, out_backprop_shape, dilation_,
+                       stride_, padding_, data_format_, &dims));
+
     Tensor* in_backprop;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1159,15 @@
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
-    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
-        dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
-        stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+    if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+        dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+        dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+        dims.stride(1) == 1 && dims.stride(2) == 1 &&
         data_format_ == FORMAT_NHWC) {
-      const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
-      const uint64 k = out_depth;
-      const uint64 n = in_depth;
+      const uint64 m = dims.batch_size * dims.input_size(0) *
+                       dims.input_size(1) * dims.input_size(2);
+      const uint64 k = dims.out_depth;
+      const uint64 n = dims.in_depth;
 
       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                   out_backprop.template flat<T>().size());
@@ -567,13 +1189,14 @@
                                             ", n=", n, ", k=", k));
       }
       return;
-    } else if (filter_size[0] == input_size[0] &&
-               filter_size[1] == input_size[1] &&
-               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
-               data_format_ == FORMAT_NHWC) {
-      const uint64 m = batch;
-      const uint64 k = out_depth;
-      const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+    } else if (dims.filter_size(0) == dims.input_size(0) &&
+               dims.filter_size(1) == dims.input_size(1) &&
+               dims.filter_size(2) == dims.input_size(2) &&
+               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+      const uint64 m = dims.batch_size;
+      const uint64 k = dims.out_depth;
+      const uint64 n = dims.input_size(0) * dims.input_size(1) *
+                       dims.input_size(2) * dims.in_depth;
 
       auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                   out_backprop.template flat<T>().size());
@@ -597,65 +1220,59 @@
       return;
     }
 
-    int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
-    if (padding_ == Padding::SAME) {
-      padding_planes = std::max<int>(
-          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
-      padding_cols = std::max<int>(
-          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
-      padding_rows = std::max<int>(
-          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
-    }
+    int padding_planes = dims.SpatialPadding(padding_, 0);
+    int padding_rows = dims.SpatialPadding(padding_, 1);
+    int padding_cols = dims.SpatialPadding(padding_, 2);
+    const bool planes_odd = (padding_planes % 2 != 0);
     const bool rows_odd = (padding_rows % 2 != 0);
     const bool cols_odd = (padding_cols % 2 != 0);
-    const bool planes_odd = (padding_planes % 2 != 0);
 
     TensorShape compatible_input_shape;
     if (rows_odd || cols_odd || planes_odd) {
       // cuDNN only supports the same amount of padding on both sides.
       compatible_input_shape = {
-          batch,
-          in_depth,
-          input_size[0] + planes_odd,
-          input_size[1] + rows_odd,
-          input_size[2] + cols_odd,
+          dims.batch_size,
+          dims.in_depth,
+          dims.input_size(0) + planes_odd,
+          dims.input_size(1) + rows_odd,
+          dims.input_size(2) + cols_odd,
       };
     } else {
-      compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
-                                input_size[2]};
+      compatible_input_shape = {dims.batch_size, dims.in_depth,
+                                dims.input_size(0), dims.input_size(1),
+                                dims.input_size(2)};
     }
 
     CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
     se::dnn::BatchDescriptor input_desc(3);
-    input_desc.set_count(batch)
+    input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
         .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
         .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
-        .set_feature_map_count(in_depth)
+        .set_feature_map_count(dims.in_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::BatchDescriptor output_desc(3);
-    output_desc.set_count(batch)
-        .set_spatial_dim(DimIndex::X, output_cols)
-        .set_spatial_dim(DimIndex::Y, output_rows)
-        .set_spatial_dim(DimIndex::Z, output_planes)
-        .set_feature_map_count(out_depth)
+    output_desc.set_count(dims.batch_size)
+        .set_spatial_dim(DimIndex::X, dims.output_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+        .set_feature_map_count(dims.out_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::FilterDescriptor filter_desc(3);
-    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
-        .set_spatial_dim(DimIndex::Y, filter_size[1])
-        .set_spatial_dim(DimIndex::Z, filter_size[0])
-        .set_input_feature_map_count(in_depth)
-        .set_output_feature_map_count(out_depth);
+    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+        .set_input_feature_map_count(dims.in_depth)
+        .set_output_feature_map_count(dims.out_depth);
     se::dnn::ConvolutionDescriptor conv_desc(3);
-    conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
-        .set_dilation_rate(DimIndex::Y, dilations[1])
-        .set_dilation_rate(DimIndex::Z, dilations[0])
-        .set_filter_stride(DimIndex::X, strides[2])
-        .set_filter_stride(DimIndex::Y, strides[1])
-        .set_filter_stride(DimIndex::Z, strides[0])
+    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+        .set_filter_stride(DimIndex::X, dims.stride(2))
+        .set_filter_stride(DimIndex::Y, dims.stride(1))
+        .set_filter_stride(DimIndex::Z, dims.stride(0))
         .set_zero_padding(DimIndex::X, padding_cols / 2)
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
         .set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,10 +1281,11 @@
     Tensor transformed_filter;
     OP_REQUIRES_OK(
         context,
-        context->allocate_temp(DataTypeToEnum<T>::value,
-                               TensorShape({out_depth, in_depth, filter_size[0],
-                                            filter_size[1], filter_size[2]}),
-                               &transformed_filter));
+        context->allocate_temp(
+            DataTypeToEnum<T>::value,
+            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+                         dims.filter_size(1), dims.filter_size(2)}),
+            &transformed_filter));
     functor::TransformFilter<GPUDevice, T, int, 5>()(
         context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
         To32Bit(transformed_filter.tensor<T, 5>()));
@@ -675,9 +1293,10 @@
     // Shape: batch, filters, z, y, x.
     Tensor transformed_out_backprop;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
-                                output_cols};
-      if (out_depth > 1) {
+      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+                                dims.output_size(0), dims.output_size(1),
+                                dims.output_size(2)};
+      if (dims.out_depth > 1) {
         OP_REQUIRES_OK(context, context->allocate_temp(
                                     DataTypeToEnum<T>::value, nchw_shape,
                                     &transformed_out_backprop));
@@ -713,14 +1332,14 @@
     const int device_id = stream->parent()->device_ordinal();
     DataType dtype = context->input(0).dtype();
     const ConvParameters conv_parameters = {
-        batch,
-        in_depth,
-        {{input_size[0], input_size[1], input_size[2]}},
+        dims.batch_size,
+        dims.in_depth,
+        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
         FORMAT_NCHW,
-        out_depth,
-        {{filter_size[0], filter_size[1], filter_size[2]}},
-        {{dilations[0], dilations[1], dilations[2]}},
-        {{strides[0], strides[1], strides[2]}},
+        dims.out_depth,
+        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
         device_id,
@@ -799,10 +1418,11 @@
     if (rows_odd || cols_odd || planes_odd) {
       Tensor in_backprop_remove_padding;
       OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<T>::value,
-                                            {batch, in_depth, input_size[0],
-                                             input_size[1], input_size[2]},
-                                            &in_backprop_remove_padding));
+                     context->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         {dims.batch_size, dims.in_depth, dims.input_size(0),
+                          dims.input_size(1), dims.input_size(2)},
+                         &in_backprop_remove_padding));
 
       // Remove the padding for odd spatial dimensions.
       functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1516,10 @@
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
     const TensorShape& input_shape = input.shape();
+
+    const Tensor& out_backprop = context->input(2);
+    const TensorShape& out_backprop_shape = out_backprop.shape();
+
     TensorShape filter_shape;
     if (takes_shape_) {
       const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1529,12 @@
       filter_shape = context->input(1).shape();
     }
 
-    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+    ConvBackpropDimensions dims;
+    OP_REQUIRES_OK(context,
+                   ConvBackpropComputeDimensionsV2(
+                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+                       input_shape, filter_shape, out_backprop_shape, dilation_,
+                       stride_, padding_, data_format_, &dims));
 
     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
@@ -914,13 +1543,15 @@
     auto* stream = context->op_device_context()->stream();
     OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
 
-    if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
-        dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
-        strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+    if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+        dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+        dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+        dims.stride(1) == 1 && dims.stride(0) == 1 &&
         data_format_ == FORMAT_NHWC) {
-      const uint64 m = in_depth;
-      const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
-      const uint64 n = out_depth;
+      const uint64 m = dims.in_depth;
+      const uint64 k = dims.batch_size * dims.input_size(1) *
+                       dims.input_size(2) * dims.input_size(0);
+      const uint64 n = dims.out_depth;
 
       // The shape of output backprop is
       //   [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1582,14 @@
                                             ", n=", n, ", k=", k));
       }
       return;
-    } else if (filter_size[0] == input_size[0] &&
-               filter_size[1] == input_size[1] &&
-               filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
-               data_format_ == FORMAT_NHWC) {
-      const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
-      const uint64 k = batch;
-      const uint64 n = out_depth;
+    } else if (dims.filter_size(0) == dims.input_size(0) &&
+               dims.filter_size(1) == dims.input_size(1) &&
+               dims.filter_size(2) == dims.input_size(2) &&
+               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+      const uint64 m = dims.input_size(0) * dims.input_size(1) *
+                       dims.input_size(2) * dims.in_depth;
+      const uint64 k = dims.batch_size;
+      const uint64 n = dims.out_depth;
 
       auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                   input.template flat<T>().size());
@@ -979,30 +1611,24 @@
       return;
     }
 
-    int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
-    if (padding_ == Padding::SAME) {
-      padding_planes = std::max<int>(
-          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
-      padding_cols = std::max<int>(
-          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
-      padding_rows = std::max<int>(
-          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
-    }
-    bool rows_odd = (padding_rows % 2 != 0);
-    bool cols_odd = (padding_cols % 2 != 0);
-    bool planes_odd = (padding_planes % 2 != 0);
+    int padding_planes = dims.SpatialPadding(padding_, 0);
+    int padding_rows = dims.SpatialPadding(padding_, 1);
+    int padding_cols = dims.SpatialPadding(padding_, 2);
+    const bool planes_odd = (padding_planes % 2 != 0);
+    const bool rows_odd = (padding_rows % 2 != 0);
+    const bool cols_odd = (padding_cols % 2 != 0);
 
     Tensor compatible_input;
     if (rows_odd || cols_odd || planes_odd) {
-      OP_REQUIRES_OK(context, context->allocate_temp(
-                                  DataTypeToEnum<T>::value,
-                                  ShapeFromFormat(data_format_, batch,
-                                                  {{input_size[0] + planes_odd,
-                                                    input_size[1] + rows_odd,
-                                                    input_size[2] + cols_odd}},
-                                                  in_depth),
-                                  &compatible_input));
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         ShapeFromFormat(data_format_, dims.batch_size,
+                                         {{dims.input_size(0) + planes_odd,
+                                           dims.input_size(1) + rows_odd,
+                                           dims.input_size(2) + cols_odd}},
+                                         dims.in_depth),
+                         &compatible_input));
       functor::PadInput<GPUDevice, T, int, 5>()(
           context->template eigen_device<GPUDevice>(),
           To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1642,35 @@
         << "Negative paddings: (" << padding_rows << ", " << padding_cols
         << ", " << padding_planes << ")";
     se::dnn::BatchDescriptor input_desc(3);
-    input_desc.set_count(batch)
+    input_desc.set_count(dims.batch_size)
         .set_spatial_dim(DimIndex::X,
                          GetTensorDim(compatible_input, data_format_, '2'))
         .set_spatial_dim(DimIndex::Y,
                          GetTensorDim(compatible_input, data_format_, '1'))
         .set_spatial_dim(DimIndex::Z,
                          GetTensorDim(compatible_input, data_format_, '0'))
-        .set_feature_map_count(in_depth)
+        .set_feature_map_count(dims.in_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::BatchDescriptor output_desc(3);
-    output_desc.set_count(batch)
-        .set_spatial_dim(DimIndex::X, output_cols)
-        .set_spatial_dim(DimIndex::Y, output_rows)
-        .set_spatial_dim(DimIndex::Z, output_planes)
-        .set_feature_map_count(out_depth)
+    output_desc.set_count(dims.batch_size)
+        .set_spatial_dim(DimIndex::X, dims.output_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+        .set_feature_map_count(dims.out_depth)
         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
     se::dnn::FilterDescriptor filter_desc(3);
-    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
-        .set_spatial_dim(DimIndex::Y, filter_size[1])
-        .set_spatial_dim(DimIndex::Z, filter_size[0])
-        .set_input_feature_map_count(in_depth)
-        .set_output_feature_map_count(out_depth);
+    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+        .set_input_feature_map_count(dims.in_depth)
+        .set_output_feature_map_count(dims.out_depth);
     se::dnn::ConvolutionDescriptor conv_desc(3);
-    conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
-        .set_dilation_rate(DimIndex::Y, dilations[1])
-        .set_dilation_rate(DimIndex::Z, dilations[0])
-        .set_filter_stride(DimIndex::X, strides[2])
-        .set_filter_stride(DimIndex::Y, strides[1])
-        .set_filter_stride(DimIndex::Z, strides[0])
+    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+        .set_filter_stride(DimIndex::X, dims.stride(2))
+        .set_filter_stride(DimIndex::Y, dims.stride(1))
+        .set_filter_stride(DimIndex::Z, dims.stride(0))
         .set_zero_padding(DimIndex::X, padding_cols / 2)
         .set_zero_padding(DimIndex::Y, padding_rows / 2)
         .set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1678,21 @@
     Tensor pre_transformed_filter_backprop;
     OP_REQUIRES_OK(
         context,
-        context->allocate_temp(DataTypeToEnum<T>::value,
-                               TensorShape({out_depth, in_depth, filter_size[0],
-                                            filter_size[1], filter_size[2]}),
-                               &pre_transformed_filter_backprop));
+        context->allocate_temp(
+            DataTypeToEnum<T>::value,
+            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+                         dims.filter_size(1), dims.filter_size(2)}),
+            &pre_transformed_filter_backprop));
 
     Tensor transformed_out_backprop;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
-                                output_cols};
+      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+                                dims.output_size(0), dims.output_size(1),
+                                dims.output_size(2)};
       OP_REQUIRES_OK(
           context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                           &transformed_out_backprop));
-      if (out_depth > 1) {
+      if (dims.out_depth > 1) {
         functor::NHWCToNCHW<GPUDevice, T, 5>()(
             context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
             transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1704,10 @@
     }
     Tensor transformed_input;
     if (data_format_ == FORMAT_NHWC) {
-      TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
-                                compatible_input.dim_size(2),
-                                compatible_input.dim_size(3)};
-      if (in_depth > 1) {
+      TensorShape nchw_shape = {
+          dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+          compatible_input.dim_size(2), compatible_input.dim_size(3)};
+      if (dims.in_depth > 1) {
         OP_REQUIRES_OK(context,
                        context->allocate_temp(DataTypeToEnum<T>::value,
                                               nchw_shape, &transformed_input));
@@ -1110,14 +1738,14 @@
     const int device_id = stream->parent()->device_ordinal();
     DataType dtype = input.dtype();
     const ConvParameters conv_parameters = {
-        batch,
-        in_depth,
-        {{input_size[0], input_size[1], input_size[2]}},
+        dims.batch_size,
+        dims.in_depth,
+        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
         FORMAT_NCHW,
-        out_depth,
-        {{filter_size[0], filter_size[1], filter_size[2]}},
-        {{dilations[0], dilations[1], dilations[2]}},
-        {{strides[0], strides[1], strides[2]}},
+        dims.out_depth,
+        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
         {{padding_planes, padding_rows, padding_cols}},
         dtype,
         device_id,
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index e7b3d0c..b3c3590 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -51,6 +51,7 @@
     hdrs = ["captured_function.h"],
     deps = [
         ":dataset",
+        ":single_threaded_executor",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -61,6 +62,42 @@
 )
 
 cc_library(
+    name = "single_threaded_executor",
+    srcs = ["single_threaded_executor.cc"],
+    hdrs = ["single_threaded_executor.h"],
+    deps = [
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_test(
+    name = "single_threaded_executor_test",
+    srcs = ["single_threaded_executor_test.cc"],
+    deps = [
+        ":single_threaded_executor",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:array",
+        "//tensorflow/core/kernels:control_flow_ops",
+        "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:math",
+        "//tensorflow/core/kernels:random_ops",
+        "//tensorflow/core/kernels:state",
+    ],
+)
+
+cc_library(
     name = "window_dataset",
     srcs = ["window_dataset.cc"],
     hdrs = ["window_dataset.h"],
@@ -638,6 +675,19 @@
 )
 
 tf_kernel_library(
+    name = "model_dataset_op",
+    srcs = ["model_dataset_op.cc"],
+    deps = [
+        ":dataset",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+tf_kernel_library(
     name = "dataset_ops",
     srcs = ["dataset_ops.cc"],
     deps = [
@@ -671,6 +721,7 @@
         ":map_and_batch_dataset_op",
         ":map_dataset_op",
         ":map_defun_op",
+        ":model_dataset_op",
         ":optimize_dataset_op",
         ":optional_ops",
         ":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index f9b5353..887b8c8 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -117,6 +117,7 @@
           : DatasetIterator<Dataset>(params) {}
 
       Status Initialize(IteratorContext* ctx) override {
+        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
       }
 
@@ -241,5 +242,5 @@
                         BatchDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 6ca0bcd..34c6c86 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level description of
@@ -69,7 +69,7 @@
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
-          new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+          new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
     }
 
     const DataTypeVector& output_dtypes() const override {
@@ -553,7 +553,7 @@
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new MemoryIterator(
-          {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+          {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
     }
 
     const DataTypeVector& output_dtypes() const override {
@@ -891,5 +891,5 @@
                         CacheDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index abdf6ee..31c8f5c 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,18 +17,29 @@
 #include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/lib/gtl/optional.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/platform/notification.h"
 
 namespace tensorflow {
+namespace data {
 
 /* static */
 Status CapturedFunction::Create(
     const NameAttrList& func, std::vector<Tensor> captured_inputs,
     std::unique_ptr<CapturedFunction>* out_function) {
-  out_function->reset(new CapturedFunction(func, std::move(captured_inputs)));
+  return Create(func, std::move(captured_inputs), true, out_function);
+}
+
+/* static */
+Status CapturedFunction::Create(
+    const NameAttrList& func, std::vector<Tensor> captured_inputs,
+    bool use_inter_op_parallelism,
+    std::unique_ptr<CapturedFunction>* out_function) {
+  out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
+                                           use_inter_op_parallelism));
   return Status::OK();
 }
 
@@ -272,6 +283,9 @@
     inst_opts.overlay_lib = ctx->function_library().get();
     inst_opts.state_handle = std::to_string(random::New64());
     inst_opts.create_kernels_eagerly = true;
+    if (!use_inter_op_parallelism_) {
+      inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
+    }
     Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
                                   inst_opts, &f_handle_));
     TF_RETURN_IF_ERROR(s);
@@ -345,7 +359,8 @@
 void CapturedFunction::RunAsync(IteratorContext* ctx,
                                 std::vector<Tensor>&& args,
                                 std::vector<Tensor>* rets,
-                                FunctionLibraryRuntime::DoneCallback done) {
+                                FunctionLibraryRuntime::DoneCallback done,
+                                const string& prefix) {
   // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
   // be deleted before `done` is called. Take care not to capture `ctx` in any
   // code that may execute asynchronously in this function.
@@ -378,30 +393,61 @@
   // will be required to plumb it through the `IteratorContext`.
   auto c_mgr = new CancellationManager;
   f_opts.cancellation_manager = c_mgr;
+  StepStats* stats = nullptr;
+  StepStatsCollector* stats_collector = nullptr;
+  std::shared_ptr<model::Node> node;
+  if (ctx->model()) {
+    node = ctx->model()->LookupNode(prefix);
+    if (node) {
+      // TODO(b/114104975): Use something light-weight here.
+      stats = new StepStats();
+      stats_collector = new StepStatsCollector(stats);
+    }
+  }
+  f_opts.stats_collector = stats_collector;
 
-  tf_shared_lock l(mu_);
-  ctx->lib()->Run(f_opts, handle, frame,
-                  std::bind(
-                      [rets, step_container, c_mgr, frame](
-                          FunctionLibraryRuntime::DoneCallback done,
-                          // Begin unbound arguments.
-                          Status s) {
-                        delete step_container;
-                        delete c_mgr;
-                        if (s.ok()) {
-                          s = frame->ConsumeRetvals(rets);
-                        }
-                        delete frame;
-                        done(s);
-                      },
-                      std::move(done), std::placeholders::_1));
+  auto callback = std::bind(
+      [rets, step_container, c_mgr, frame, stats, stats_collector, node](
+          FunctionLibraryRuntime::DoneCallback done,
+          // Begin unbound arguments.
+          Status s) {
+        delete step_container;
+        delete c_mgr;
+        if (s.ok()) {
+          s = frame->ConsumeRetvals(rets);
+        }
+        delete frame;
+        if (node) {
+          int64 delta = 0;
+          stats_collector->Finalize();
+          for (auto dev_stats : stats->dev_stats()) {
+            for (auto node_stats : dev_stats.node_stats()) {
+              delta += node_stats.all_end_rel_nanos();
+            }
+          }
+          delete stats_collector;
+          delete stats;
+          node->add_processing_time(delta);
+          node->start_work();
+        }
+        done(s);
+        if (node) {
+          node->stop_work();
+        }
+      },
+      std::move(done), std::placeholders::_1);
+
+  ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
 }
 
 CapturedFunction::CapturedFunction(const NameAttrList& func,
-                                   std::vector<Tensor> captured_inputs)
+                                   std::vector<Tensor> captured_inputs,
+                                   bool use_inter_op_parallelism)
     : func_(func),
       lib_(nullptr),
       f_handle_(kInvalidHandle),
-      captured_inputs_(std::move(captured_inputs)) {}
+      captured_inputs_(std::move(captured_inputs)),
+      use_inter_op_parallelism_(use_inter_op_parallelism) {}
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index c95f2b1..8b420fa 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -32,6 +32,8 @@
 class OpKernelContext;
 class ResourceMgr;
 
+namespace data {
+
 // A `CapturedFunction` encapsulates a TensorFlow function and all of
 // the runtime support required to execute it.
 //
@@ -48,6 +50,15 @@
                        std::vector<Tensor> captured_inputs,
                        std::unique_ptr<CapturedFunction>* out_function);
 
+  // Creates a new instance from a list of named attributes and captured inputs.
+  //
+  // If `use_inter_op_parallelism` is false, the runtime may use an executor
+  // that is optimized for small functions.
+  static Status Create(const NameAttrList& func,
+                       std::vector<Tensor> captured_inputs,
+                       bool use_inter_op_parallelism,
+                       std::unique_ptr<CapturedFunction>* out_function);
+
   // Creates a new instance using a list of named attributes, fetching captured
   // inputs from a context argument.
   static Status Create(const NameAttrList& func, OpKernelContext* ctx,
@@ -93,7 +104,8 @@
   // in order to be able to deallocate them as early as possible.
   void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
                 std::vector<Tensor>* rets,
-                FunctionLibraryRuntime::DoneCallback done);
+                FunctionLibraryRuntime::DoneCallback done,
+                const string& prefix);
 
   // Returns the named list of function arguments.
   const NameAttrList& func() { return func_; }
@@ -114,7 +126,8 @@
 
  private:
   CapturedFunction(const NameAttrList& func,
-                   std::vector<Tensor> captured_inputs);
+                   std::vector<Tensor> captured_inputs,
+                   bool use_inter_op_parallelism);
 
   Status GetHandle(IteratorContext* ctx,
                    FunctionLibraryRuntime::Handle* out_handle);
@@ -126,10 +139,17 @@
   const std::vector<Tensor> captured_inputs_;
   DataTypeSlice ret_types_;
   std::function<void(std::function<void()>)> captured_runner_ = nullptr;
+  const bool use_inter_op_parallelism_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
 };
 
+}  // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::CapturedFunction;
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index c361a9a..a04f150 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -195,5 +195,5 @@
                         ConcatenateDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index c71d027..bd1ccd5 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
+namespace data {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following op.
@@ -48,4 +49,5 @@
 REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
                         DatasetToGraphOp);
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index d85ef1c..e7ac368 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -17,8 +17,7 @@
 #include "tensorflow/core/common_runtime/device.h"
 
 namespace tensorflow {
-
-namespace dataset {
+namespace data {
 
 Status MakeIteratorFromInputElement(
     IteratorContext* ctx, const std::vector<Tensor>& input_element,
@@ -45,6 +44,5 @@
       ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
 }
 
-}  // namespace dataset
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6c4191c..234856e 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -20,16 +20,14 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
-namespace dataset {
+namespace data {
 
 Status MakeIteratorFromInputElement(
     IteratorContext* ctx, const std::vector<Tensor>& input_element,
     int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
     std::unique_ptr<IteratorBase>* out_iterator);
 
-}  // namespace dataset
-
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 9770bc0..237511a 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -301,5 +301,5 @@
                         DenseToSparseBatchDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index ce57739..a7e3a56 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -166,5 +166,5 @@
                         FilterByLastComponentDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bbce001..bf0aeca 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -280,5 +280,5 @@
                         FilterDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index b1eb2fd..e3c45ef 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -245,7 +245,7 @@
      private:
       Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        return dataset::MakeIteratorFromInputElement(
+        return MakeIteratorFromInputElement(
             ctx, captured_func_inputs_, element_index_++,
             dataset()->captured_func_.get(), prefix(),
             &current_element_iterator_);
@@ -285,5 +285,5 @@
                         FlatMapDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ccee690..ac5cc1b 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
+namespace data {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following op.
@@ -188,10 +189,13 @@
                   std::move(finalize_func), output_types_, output_shapes_);
 }
 
+namespace {
 REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
                         GeneratorDatasetOp);
 REGISTER_KERNEL_BUILDER(
     Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
     GeneratorDatasetOp);
+}  // namespace
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
index 8407543..d23ed97 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.h
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/dataset.h"
 
 namespace tensorflow {
+namespace data {
 
 class GeneratorDatasetOp : public DatasetOpKernel {
  public:
@@ -36,5 +37,6 @@
   NameAttrList finalize_func_;
 };
 
+}  // namespace data
 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index 130f04d..d6ee42a 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -433,4 +434,5 @@
                         GroupByReducerDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 46a3185..e4fa557 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -23,6 +23,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -549,4 +550,5 @@
                         GroupByWindowDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 716e040..0768f46 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -201,7 +201,7 @@
             TF_RETURN_IF_ERROR(input_impl_->GetNext(
                 ctx, &args_list_[cycle_index_], &end_of_input_));
             if (!end_of_input_) {
-              TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+              TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
                   ctx, args_list_[cycle_index_], cycle_index_,
                   dataset()->captured_func_.get(), prefix(),
                   &current_elements_[cycle_index_]));
@@ -288,7 +288,7 @@
                   full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
                   &args_list_[idx][i]));
             }
-            TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+            TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
                 ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
                 prefix(), &current_elements_[idx]));
             TF_RETURN_IF_ERROR(
@@ -330,5 +330,5 @@
                         InterleaveDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 4e9b280..30c6585 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -36,7 +36,7 @@
 #include "tensorflow/core/public/session_options.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -236,6 +236,8 @@
   const std::vector<PartialTensorShape> output_shapes_;
 };
 
+namespace {
+
 // Helper class for reading data from a VariantTensorData object.
 class VariantTensorDataReader : public IteratorStateReader {
  public:
@@ -401,12 +403,12 @@
   }
   string TypeName() const { return kIteratorVariantTypeName; }
   void Encode(VariantTensorData* data) const { *data = *data_; }
-  bool Decode(const VariantTensorData& data) {
+  bool Decode(VariantTensorData data) {
     if (data.type_name() != TypeName()) {
       return false;
     }
     std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
-    *tensor_data = data;
+    std::swap(*tensor_data, data);
     std::unique_ptr<VariantTensorDataReader> reader(
         new VariantTensorDataReader(tensor_data.get()));
     status_ = reader->status();
@@ -443,6 +445,8 @@
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
                                        kIteratorVariantTypeName);
 
+}  // namespace
+
 // Note that IteratorHandleOp holds a reference to the resource it creates. If
 // cleaning up resources with DestroyResourceOp is important, consider creating
 // resource containers with AnonymousIteratorHandleOp instead.
@@ -622,6 +626,8 @@
   OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
 }
 
+namespace {
+
 class ToSingleElementOp : public AsyncOpKernel {
  public:
   explicit ToSingleElementOp(OpKernelConstruction* ctx)
@@ -887,6 +893,8 @@
   const int graph_def_version_;
 };
 
+}  // namespace
+
 void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
   IteratorResource* iterator;
   OP_REQUIRES_OK_ASYNC(
@@ -957,6 +965,8 @@
   }
 }
 
+namespace {
+
 class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
  public:
   explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
@@ -1037,6 +1047,8 @@
   std::vector<PartialTensorShape> output_shapes_;
 };
 
+}  // namespace
+
 void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
   const Tensor& resource_handle_t = ctx->input(0);
   OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
@@ -1108,6 +1120,8 @@
   resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
 }
 
+namespace {
+
 class SerializeIteratorOp : public OpKernel {
  public:
   explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1202,4 +1216,7 @@
 REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
                         DeserializeIteratorOp);
 
+}  // namespace
+
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index 7235642..8a2b263 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/kernels/ops_util.h"
 
 namespace tensorflow {
+namespace data {
 
 class IteratorResource;
 
@@ -142,6 +143,7 @@
   std::vector<PartialTensorShape> output_shapes_;
 };
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 8b0c9ad..85e4935 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -29,7 +29,7 @@
 #include "tensorflow/core/platform/tracing.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -204,6 +204,8 @@
       }
 
       Status Initialize(IteratorContext* ctx) override {
+        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
+        SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         return dataset()->captured_func_->Instantiate(ctx);
@@ -218,7 +220,9 @@
           EnsureRunnerThreadStarted(ctx);
           while (batch_results_.empty() ||
                  batch_results_.front()->num_calls > 0) {
+            StopWork(ctx);
             cond_var_.wait(l);
+            StartWork(ctx);
           }
           std::swap(result, batch_results_.front());
           batch_results_.pop_front();
@@ -365,7 +369,8 @@
                   ctx.get(), std::move(input_element), return_values.get(),
                   [this, ctx, result, return_values, offset](Status status) {
                     Callback(ctx, result, return_values, offset, status);
-                  });
+                  },
+                  prefix());
             },
             ctx, std::move(input_element)));
       }
@@ -476,6 +481,9 @@
           LOCKS_EXCLUDED(mu_) {
         std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
         new_calls.reserve(dataset()->num_parallel_calls_);
+        StartWork(ctx.get());
+        auto stop_cleanup =
+            gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
         while (true) {
           {
             mutex_lock l(mu_);
@@ -484,7 +492,9 @@
                     batch_results_.size() > MaxBatchResults() ||
                     (batch_results_.size() == MaxBatchResults() &&
                      call_counter_ % dataset()->batch_size_ == 0))) {
+              StopWork(ctx.get());
               cond_var_.wait(l);
+              StartWork(ctx.get());
             }
 
             if (cancelled_) {
@@ -675,5 +685,5 @@
                         MapAndBatchDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 7f8182d..af301e2 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,12 +28,12 @@
 
 class MapDatasetOp : public UnaryDatasetOpKernel {
  public:
-  explicit MapDatasetOp(OpKernelConstruction* ctx)
-      : UnaryDatasetOpKernel(ctx),
-        graph_def_version_(ctx->graph_def_version()) {
+  explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+                                     &use_inter_op_parallelism_));
   }
 
   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
@@ -48,7 +48,8 @@
 
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+                            func_, std::move(other_arguments),
+                            use_inter_op_parallelism_, &captured_func));
 
     *output = new Dataset(ctx, input, func_, std::move(captured_func),
                           output_types_, output_shapes_);
@@ -183,14 +184,14 @@
     const std::vector<PartialTensorShape> output_shapes_;
   };
 
-  const int graph_def_version_;
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
   NameAttrList func_;
+  bool use_inter_op_parallelism_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 607d0ca..6657f2b 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,18 +18,20 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/util/batch_util.h"
 #include "tensorflow/core/util/reffed_status_callback.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
                    bool always_collect_stats) {
   opts->step_id = ctx->step_id();
   opts->rendezvous = ctx->rendezvous();
-  opts->cancellation_manager = ctx->cancellation_manager();
   if (always_collect_stats) {
     opts->stats_collector = ctx->stats_collector();
   }
@@ -60,103 +62,186 @@
 
   ~MapDefunOp() override {}
 
-  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
-    int64 batch_size = ctx->input(0).dim_size(0);
-    // Inputs
-    auto* args = new std::vector<Tensor>;
-    auto* arg_shapes = new std::vector<TensorShape>;
-    arg_shapes->reserve(ctx->num_inputs());
-    args->reserve(ctx->num_inputs());
-
+  Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
+    // Validates inputs and gets the size of their leading dimension.
+    *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
     for (size_t i = 0; i < ctx->num_inputs(); ++i) {
-      args->push_back(ctx->input(i));
-      arg_shapes->push_back(ctx->input(i).shape());
-      arg_shapes->at(i).RemoveDim(0);  // Remove the first batch dimension
-      OP_REQUIRES_ASYNC(
-          ctx, batch_size == ctx->input(i).dim_size(0),
-          errors::InvalidArgument(
-              "All inputs must have the same dimension 0. Input ", i,
-              " has leading dimension ", ctx->input(i).dim_size(0),
-              ", while all previous inputs have leading dimension ", batch_size,
-              "."),
-          done);
+      if (ctx->input(i).dims() == 0) {
+        return errors::InvalidArgument(
+            "All inputs must have rank at least 1. Input ", i,
+            " has a rank of 0.");
+      } else if (ctx->input(i).dim_size(0) != *batch_size) {
+        return errors::InvalidArgument(
+            "All inputs must have the same dimension 0. Input ", i,
+            " has leading dimension ", ctx->input(i).dim_size(0),
+            ", while all previous inputs have leading dimension ", *batch_size);
+      }
     }
+    return Status::OK();
+  }
 
-    // Outputs
-    auto* output = new OpOutputList;
-    OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+    ComputeOptions* compute_opts = nullptr;
 
-    for (size_t i = 0; i < output_types().size(); ++i) {
-      Tensor* out = nullptr;
-      TensorShape output_shape = output_shapes_.at(i);
-      output_shape.InsertDim(0, batch_size);
-      OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
-    }
+    OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
 
-    SetRunOptions(ctx, &opts_, false);
+    Status s = SetupOutputs(ctx, compute_opts);
+    if (!s.ok()) delete compute_opts;
+    OP_REQUIRES_OK_ASYNC(ctx, s, done);
+
+    FunctionLibraryRuntime::Options opts;
+    SetRunOptions(ctx, &opts, false);
 
     // Run loop
     StatusCallback callback = std::bind(
-        [](OpKernelContext* ctx, std::vector<Tensor>* args,
-           std::vector<TensorShape>* arg_shapes, OpOutputList* output,
+        [](OpKernelContext* ctx, ComputeOptions* compute_opts,
            DoneCallback& done, const Status& status) {
-          delete args;
-          delete arg_shapes;
-          delete output;
+          delete compute_opts;
           ctx->SetStatus(status);
           done();
         },
-        ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+        ctx, compute_opts, std::move(done), std::placeholders::_1);
 
     auto* refcounted = new ReffedStatusCallback(std::move(callback));
 
-    for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
-      // Start from i = 1 because refcounted is initialized with refcount = 1
+    CancellationManager* parent_mgr = ctx->cancellation_manager();
+
+    for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+      // We use a different cancellation manager each time the function is run
+      // to avoid the race condition between a function run error and other
+      // functions being cancelled as a result.
+      CancellationManager* c_mgr = new CancellationManager;
+      CancellationToken token = parent_mgr->get_cancellation_token();
+      const bool success = parent_mgr->RegisterCallback(
+          token, [c_mgr]() { c_mgr->StartCancel(); });
+
+      opts.cancellation_manager = c_mgr;
+      if (!success) {
+        delete c_mgr;
+        refcounted->UpdateStatus(errors::Cancelled(
+            "MapDefunOp functions cancelled because parent graph cancelled"));
+        break;
+      }
+
+      auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
       refcounted->Ref();
+      ctx->function_library()->Run(opts, func_handle_, call_frame,
+                                   [call_frame, refcounted, c_mgr, parent_mgr,
+                                    token](const Status& func_status) {
+                                     parent_mgr->DeregisterCallback(token);
+                                     delete c_mgr;
+                                     delete call_frame;
+                                     refcounted->UpdateStatus(func_status);
+                                     refcounted->Unref();
+                                   });
     }
-    for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
-      auto* call_frame =
-          new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
-      ctx->function_library()->Run(
-          opts_, func_handle_, call_frame,
-          [call_frame, refcounted](const Status& func_status) {
-            delete call_frame;
-            refcounted->UpdateStatus(func_status);
-            refcounted->Unref();
-          });
-    }
+
+    // Unref 1 because refcounted is initialized with refcount = 1
+    refcounted->Unref();
   }
 
  private:
   FunctionLibraryRuntime::Handle func_handle_;
-  FunctionLibraryRuntime::Options opts_;
-  std::vector<TensorShape> output_shapes_;
+  std::vector<PartialTensorShape> output_shapes_;
+
+  struct ComputeOptions {
+    // These vary per MapDefunOp::ComputeAsync call, but must persist until
+    // all calls to the function are complete. This struct also encapsulates
+    // all the components that need to be passed to each MapFunctionCallFrame.
+
+    const std::vector<Tensor> args;
+    const std::vector<TensorShape> arg_shapes;
+    const int64 batch_size;
+
+    // Output of a compute call
+    std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+    OpOutputList output GUARDED_BY(mu);
+    mutex mu;
+
+    // Create a copy of output_shapes because every `Compute` may expect a
+    // different output shape.
+    ComputeOptions(std::vector<Tensor> args,
+                   std::vector<TensorShape> arg_shapes, int64 batch_size,
+                   const std::vector<PartialTensorShape>& output_shapes_attr)
+        : args(std::move(args)),
+          arg_shapes(std::move(arg_shapes)),
+          batch_size(batch_size),
+          output_shapes(output_shapes_attr) {}
+  };
+
+  // Get inputs to Compute and check that they are valid.
+  Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+    int64 batch_size =
+        ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+    for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+      if (ctx->input(i).dims() == 0) {
+        return errors::InvalidArgument(
+            "All inputs must have rank at least 1. Input ", i,
+            " has a rank of 0.");
+      } else if (ctx->input(i).dim_size(0) != batch_size) {
+        return errors::InvalidArgument(
+            "All inputs must have the same dimension 0. Input ", i,
+            " has leading dimension ", ctx->input(i).dim_size(0),
+            ", while all previous inputs have leading dimension ", batch_size);
+      }
+    }
+
+    std::vector<Tensor> args;
+    std::vector<TensorShape> arg_shapes;
+    args.reserve(ctx->num_inputs());
+    arg_shapes.reserve(ctx->num_inputs());
+
+    for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+      args.push_back(ctx->input(i));
+      arg_shapes.push_back(ctx->input(i).shape());
+      arg_shapes.at(i).RemoveDim(0);
+    }
+
+    *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+                                       batch_size, output_shapes_);
+    return Status::OK();
+  }
+
+  Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+    mutex_lock l(opts->mu);
+    TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+    for (size_t i = 0; i < output_types().size(); ++i) {
+      if (output_shapes_.at(i).IsFullyDefined()) {
+        Tensor* out = nullptr;
+        TensorShape output_shape;
+        output_shapes_.at(i).AsTensorShape(&output_shape);
+        output_shape.InsertDim(0, opts->batch_size);
+        TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+      }
+    }
+    return Status::OK();
+  }
 
   class MapFunctionCallFrame : public CallFrameInterface {
    public:
-    MapFunctionCallFrame(const std::vector<Tensor>& args,
-                         const std::vector<TensorShape>& arg_shapes,
-                         OpOutputList* output, OpKernel* kernel, size_t iter)
-        : args_(args),
-          arg_shapes_(arg_shapes),
-          output_(output),
-          kernel_(kernel),
-          iter_(iter) {}
+    MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+                         size_t iter)
+        : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
 
     ~MapFunctionCallFrame() override {}
 
-    size_t num_args() const override { return args_.size(); }
+    size_t num_args() const override { return compute_opts_->args.size(); }
+
     size_t num_retvals() const override {
       return static_cast<size_t>(kernel_->num_outputs());
     }
 
     Status GetArg(int index, Tensor* val) const override {
-      if (index < 0 || index >= args_.size()) {
+      if (index < 0 || index >= compute_opts_->args.size()) {
         return errors::InvalidArgument(
             "Mismatch in number of function inputs.");
       }
-      bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
-                                  arg_shapes_.at(index));
+      bool result =
+          val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+                        compute_opts_->arg_shapes.at(index));
       if (!result) {
         return errors::Internal("GetArg failed.");
       } else if (!val->IsAligned()) {
@@ -179,18 +264,39 @@
             "output: ",
             index);
       }
-      return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
+      {  // Locking scope
+        mutex_lock l(compute_opts_->mu);
+        if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+                val.shape())) {
+          return errors::InvalidArgument(
+              "Mismatch in function retval shape, ", val.shape(),
+              ", and expected output shape, ",
+              compute_opts_->output_shapes.at(index).DebugString(), ".");
+        }
+        if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
+          // Given val, we have new information about the output shape at
+          // this index. Store the shape and allocate the output accordingly.
+          compute_opts_->output_shapes.at(index) = val.shape();
+
+          Tensor* out = nullptr;
+          TensorShape actual_shape = val.shape();
+          actual_shape.InsertDim(0, compute_opts_->batch_size);
+          TF_RETURN_IF_ERROR(
+              compute_opts_->output.allocate(index, actual_shape, &out));
+        }
+        return batch_util::CopyElementToSlice(
+            val, (compute_opts_->output)[index], iter_);
+      }
     }
 
    private:
-    const std::vector<Tensor>& args_;
-    const std::vector<TensorShape>& arg_shapes_;
-    OpOutputList* output_;
+    ComputeOptions* const compute_opts_;  // Not owned
     const OpKernel* kernel_;
     const size_t iter_;
   };
-};  // namespace
+};
 
 REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000..c7f929d
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,127 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit ModelDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    *output = new Dataset(ctx, input);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+        : DatasetBase(DatasetContext(ctx)), input_(input) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::Model")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return input_->output_shapes();
+    }
+
+    string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params), model_(new model::Model()) {}
+
+      ~Iterator() override { model_->OutputToFile(); }
+
+      Status Initialize(IteratorContext* ctx) override {
+        IteratorContext ctx_with_model(CreateParams(ctx));
+        return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+                                               &input_impl_);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        IteratorContext ctx_with_model(CreateParams(ctx));
+        return input_impl_->GetNext(&ctx_with_model, out_tensors,
+                                    end_of_sequence);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
+      IteratorContext::Params CreateParams(IteratorContext* ctx) {
+        IteratorContext::Params params = ctx->params();
+        params.model = model_;
+        return params;
+      }
+
+     private:
+      mutex mu_;
+      std::shared_ptr<model::Model> model_;
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* input_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+                        ModelDatasetOp);
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 6263dc3..d5b725e 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -33,6 +33,7 @@
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -270,4 +271,5 @@
                         OptimizeDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index cfac45d..6180df5 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/variant_op_registry.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
 
@@ -230,10 +231,9 @@
   return Status::OK();
 }
 
-#define REGISTER_OPTIONAL_COPY(DIRECTION)                   \
-  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(     \
-      OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
-      OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
+  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+      OptionalVariant, DIRECTION, OptionalDeviceCopy)
 
 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -267,4 +267,5 @@
   return Status::OK();
 }
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
index 6f25567..2cbf293 100644
--- a/tensorflow/core/kernels/data/optional_ops.h
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -21,6 +21,7 @@
 #include "tensorflow/core/framework/variant_tensor_data.h"
 
 namespace tensorflow {
+namespace data {
 
 // Stores a DT_VARIANT value representing an Optional with the given value
 // in the `output_index`^th output of the given kernel execution context.
@@ -31,6 +32,7 @@
 // in the `output_index`^th output of the given kernel execution context.
 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index be45eac..73eeafd 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -19,7 +19,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -207,6 +207,7 @@
           : DatasetIterator<Dataset>(params) {}
 
       Status Initialize(IteratorContext* ctx) override {
+        SetMetadata(ctx, "batch_size", dataset()->batch_size_);
         return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
       }
 
@@ -382,5 +383,5 @@
                         PaddedBatchDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index f6b3fd9..aa5e613 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include <deque>
+#include <utility>
 
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,11 +22,12 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -34,8 +36,7 @@
 class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
-      : UnaryDatasetOpKernel(ctx),
-        graph_def_version_(ctx->graph_def_version()) {
+      : UnaryDatasetOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -125,6 +126,7 @@
     const DataTypeVector& output_dtypes() const override {
       return output_types_;
     }
+
     const std::vector<PartialTensorShape>& output_shapes() const override {
       return output_shapes_;
     }
@@ -250,6 +252,7 @@
       }
 
       Status Initialize(IteratorContext* ctx) override {
+        SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
         TF_RETURN_IF_ERROR(
             dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         return dataset()->captured_func_->Instantiate(ctx);
@@ -349,11 +352,13 @@
 
           if (must_wait_for_input) {
             // Wait for elements to become available.
+            StopWork(ctx);
             if (dataset()->sloppy_) {
               sloppy_cond_var_.wait(l);
             } else {
               workers_[interleave_indices_[next_index_]].cond_var.wait(l);
             }
+            StartWork(ctx);
           }
         }
         return errors::Cancelled(
@@ -482,10 +487,10 @@
         if (reader->Contains(full_name("worker_threads_running"))) {
           worker_threads_.reserve(dataset()->num_threads());
           for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
             worker_threads_.emplace_back(ctx->env()->StartThread(
                 {}, "worker_thread",
-                std::bind(&Iterator::WorkerThread, this,
-                          new IteratorContext(*ctx), i)));
+                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
           }
         }
         return Status::OK();
@@ -581,10 +586,10 @@
               return Status::OK();
             }
             workers_[i].SetInputs(s, std::move(args));
+            std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
             worker_threads_.emplace_back(ctx->env()->StartThread(
                 {}, "worker_thread",
-                std::bind(&Iterator::WorkerThread, this,
-                          new IteratorContext(*ctx), i)));
+                [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
             if (i < dataset()->cycle_length_) {
               interleave_indices_.push_back(i);
             } else {
@@ -599,7 +604,8 @@
       }
 
       // Produces elements into the worker's output buffers.
-      void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+      void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+                        const int64 thread_index) {
         // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
         //
         // 1. Any local state that may need to be checkpointed should be kept
@@ -620,10 +626,11 @@
 
         // std::function arguments are copy-constructable, so we pass raw
         // pointers, and then immediately wrap them to ensure correct ownership.
-        std::unique_ptr<IteratorContext> ctx(ctx_ptr);
-        auto cleanup = gtl::MakeCleanup([this, thread_index] {
+        StartWork(ctx.get());
+        auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
           mutex_lock l(mu_);
           workers_[thread_index].cond_var.notify_all();
+          StopWork(ctx.get());
         });
         bool make_new_iterator;
         {
@@ -649,9 +656,7 @@
           // 1. Build a new iterator or use the existing one.
           if (make_new_iterator) {
             // 1a. Get new input tensors or use the exiting ones.
-
             bool read_new_input;
-
             {
               tf_shared_lock l(ckpt_mu_);
               // worker_thread_states_[thread_index].input will be non-empty
@@ -663,7 +668,9 @@
             if (read_new_input) {
               mutex_lock l(mu_);
               while (!cancelled_ && !workers_[thread_index].is_producing) {
+                StopWork(ctx.get());
                 workers_[thread_index].cond_var.wait(l);
+                StartWork(ctx.get());
               }
               if (cancelled_) return;
               // Copy the input tensors so that we do not need to block on `mu_`
@@ -684,7 +691,7 @@
             {
               tf_shared_lock l(ckpt_mu_);
               worker_thread_states_[thread_index].iterator_creation_status =
-                  dataset::MakeIteratorFromInputElement(
+                  MakeIteratorFromInputElement(
                       ctx.get(), worker_thread_states_[thread_index].input,
                       thread_index, dataset()->captured_func_.get(), prefix(),
                       &worker_thread_states_[thread_index].iterator);
@@ -713,7 +720,9 @@
             // Wait for space in the prefetch queue.
             while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                       dataset()->buffer_output_elements_) {
+              StopWork(ctx.get());
               workers_[thread_index].cond_var.wait(l);
+              StartWork(ctx.get());
             }
             if (cancelled_) return;
             tf_shared_lock ckpt_l(ckpt_mu_);
@@ -762,7 +771,9 @@
                 // Wait for space in the prefetch queue.
                 while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                           dataset()->buffer_output_elements_) {
+                  StopWork(ctx.get());
                   workers_[thread_index].cond_var.wait(l);
+                  StartWork(ctx.get());
                 }
                 if (cancelled_) return;
 
@@ -914,7 +925,7 @@
           worker_thread_states_[index].iterator.reset();
         } else {
           std::unique_ptr<IteratorBase> iterator;
-          Status s = dataset::MakeIteratorFromInputElement(
+          Status s = MakeIteratorFromInputElement(
               ctx, worker_thread_states_[index].input, index,
               dataset()->captured_func_.get(), prefix(), &iterator);
           TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
@@ -1058,7 +1069,6 @@
     const std::vector<PartialTensorShape> output_shapes_;
   };
 
-  const int graph_def_version_;
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
   NameAttrList interleave_func_;
@@ -1067,6 +1077,604 @@
 REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                         ParallelInterleaveDatasetOp);
 
-}  // namespace
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+  explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
 
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    OpInputList inputs;
+    OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+
+    int64 cycle_length = 0;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+    OP_REQUIRES(ctx, cycle_length > 0,
+                errors::InvalidArgument("`cycle_length` must be > 0"));
+
+    int64 block_length = 0;
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument(ctx, "block_length", &block_length));
+    OP_REQUIRES(ctx, block_length > 0,
+                errors::InvalidArgument("`block_length` must be > 0"));
+
+    int64 num_parallel_calls;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+                                            &num_parallel_calls));
+    OP_REQUIRES(ctx, num_parallel_calls > 0,
+                errors::InvalidArgument(
+                    "num_parallel_calls must be greater than zero."));
+    OP_REQUIRES(
+        ctx, num_parallel_calls <= cycle_length,
+        errors::InvalidArgument(
+            "num_parallel_calls must be less than or equal to cycle_length."));
+
+    // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
+    std::vector<Tensor> other_arguments;
+    other_arguments.reserve(inputs.size());
+    for (const Tensor& t : inputs) {
+      other_arguments.push_back(t);
+    }
+    std::unique_ptr<CapturedFunction> captured_func;
+    OP_REQUIRES_OK(
+        ctx, CapturedFunction::Create(
+                 interleave_func_, std::move(other_arguments), &captured_func));
+
+    *output = new Dataset(ctx, input, interleave_func_,
+                          std::move(captured_func), cycle_length, block_length,
+                          num_parallel_calls, output_types_, output_shapes_);
+  }
+
+ private:
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* input,
+            const NameAttrList& func,
+            std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+            int64 block_length, int64 num_parallel_calls,
+            const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes)
+        : DatasetBase(DatasetContext(ctx)),
+          input_(input),
+          interleave_func_(func),
+          captured_func_(std::move(captured_func)),
+          cycle_length_(cycle_length),
+          block_length_(block_length),
+          num_parallel_calls_(num_parallel_calls),
+          output_types_(output_types),
+          output_shapes_(output_shapes) {
+      input_->Ref();
+    }
+
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return output_types_;
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() const override {
+      return "ParallelInterleaveDatasetV2Op::Dataset";
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+      Node* input_node;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+      Node* cycle_length_node;
+      TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+      Node* block_length_node;
+      TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+      Node* num_parallel_calls_node;
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+      DataTypeVector other_arguments_types;
+      other_arguments_types.reserve(captured_func_->captured_inputs().size());
+      std::vector<Node*> other_arguments;
+      other_arguments.reserve(captured_func_->captured_inputs().size());
+      for (const Tensor& t : captured_func_->captured_inputs()) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        other_arguments.emplace_back(node);
+        other_arguments_types.emplace_back(t.dtype());
+      }
+      AttrValue f;
+      b->BuildAttrValue(interleave_func_, &f);
+      AttrValue other_arguments_types_attr;
+      b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this,
+          {{0, input_node},
+           {2, cycle_length_node},
+           {3, block_length_node},
+           {4, num_parallel_calls_node}},
+          {{1, other_arguments}},
+          {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            args_list_(params.dataset->cycle_length_),
+            current_elements_(params.dataset->cycle_length_),
+            element_in_use_(params.dataset->cycle_length_, false),
+            thread_pool_(new thread::ThreadPool(
+                Env::Default(), ThreadOptions(), "parallel_interleave",
+                dataset()->cycle_length_ /* num_threads */,
+                false /* low_latency_hint */)) {}
+
+      ~Iterator() override {
+        mutex_lock l(mu_);
+        // Cancel the runner thread.
+        cancelled_ = true;
+        cond_var_.notify_all();
+        // Wait for all in-flight calls to complete.
+        while (num_calls_ > 0) {
+          cond_var_.wait(l);
+        }
+      }
+
+      Status Initialize(IteratorContext* ctx) override {
+        SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
+        TF_RETURN_IF_ERROR(
+            dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+        return dataset()->captured_func_->Instantiate(ctx);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        std::shared_ptr<InvocationResult> result;
+        do {
+          {
+            mutex_lock l(mu_);
+            EnsureRunnerThreadStarted(ctx);
+            while (invocation_results_.empty() &&
+                   (!end_of_input_ || num_open_ > 0)) {
+              StopWork(ctx);
+              cond_var_.wait(l);
+              StartWork(ctx);
+            }
+            if (!invocation_results_.empty()) {
+              std::swap(result, invocation_results_.front());
+              invocation_results_.pop_front();
+            } else {
+              *end_of_sequence = true;
+              return Status::OK();
+            }
+          }
+          cond_var_.notify_all();
+          StopWork(ctx);
+          result->notification.WaitForNotification();
+          StartWork(ctx);
+        } while (result->skip);
+
+        if (result->status.ok()) {
+          *out_tensors = std::move(result->return_values);
+        }
+        *end_of_sequence = false;
+        return result->status;
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        // Wait for all in-flight calls to complete.
+        while (num_calls_ > 0) {
+          cond_var_.wait(l);
+        }
+        CHECK_EQ(num_calls_, 0);
+        TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            full_name("invocation_results.size"), invocation_results_.size()));
+        for (size_t i = 0; i < invocation_results_.size(); i++) {
+          std::shared_ptr<InvocationResult> result = invocation_results_[i];
+          TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+          TF_RETURN_IF_ERROR(writer->WriteScalar(
+              full_name(strings::StrCat("invocation_results[", i, "].size")),
+              result->return_values.size()));
+          for (size_t j = 0; j < result->return_values.size(); j++) {
+            TF_RETURN_IF_ERROR(writer->WriteTensor(
+                full_name(
+                    strings::StrCat("invocation_results[", i, "][", j, "]")),
+                result->return_values[j]));
+          }
+          if (result->skip) {
+            TF_RETURN_IF_ERROR(writer->WriteScalar(
+                full_name(strings::StrCat("invocation_results[", i, "].skip")),
+                ""));
+          }
+        }
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+        if (end_of_input_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("end_of_input"), ""));
+        }
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("num_open"), num_open_));
+        TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        int64 invocation_results_size;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(
+            full_name("invocation_results.size"), &invocation_results_size));
+        for (size_t i = 0; i < invocation_results_size; i++) {
+          std::shared_ptr<InvocationResult> result(new InvocationResult());
+          invocation_results_.push_back(result);
+          TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+          size_t num_return_values;
+          {
+            int64 size;
+            TF_RETURN_IF_ERROR(reader->ReadScalar(
+                full_name(strings::StrCat("invocation_results[", i, "].size")),
+                &size));
+            num_return_values = static_cast<size_t>(size);
+            if (num_return_values != size) {
+              return errors::InvalidArgument(strings::StrCat(
+                  full_name(
+                      strings::StrCat("invocation_results[", i, "].size")),
+                  ": ", size, " is not a valid value of type size_t."));
+            }
+          }
+          result->return_values.reserve(num_return_values);
+          for (size_t j = 0; j < num_return_values; j++) {
+            result->return_values.emplace_back();
+            TF_RETURN_IF_ERROR(
+                reader->ReadTensor(full_name(strings::StrCat(
+                                       "invocation_results[", i, "][", j, "]")),
+                                   &result->return_values.back()));
+          }
+          result->skip = reader->Contains(
+              full_name(strings::StrCat("invocation_results[", i, "].skip")));
+          result->notification.Notify();
+        }
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+        if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("num_open"), &num_open_));
+        TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+        return Status::OK();
+      }
+
+     private:
+      struct InvocationResult {
+        Notification notification;  // used for coordination with the consumer
+        Status status;              // the invocation status
+        std::vector<Tensor> return_values;  // the invocation result values
+        bool skip;  // if set the result should be skipped
+      };
+
+      void EnsureRunnerThreadStarted(IteratorContext* ctx)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (!runner_thread_) {
+          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+          runner_thread_.reset(ctx->env()->StartThread(
+              {}, "runner_thread",
+              [this, new_ctx]() { RunnerThread(new_ctx); }));
+        }
+      }
+
+      // Fetches up to `results.size()` outputs from the cycle element at
+      // position `cycle_index`.
+      //
+      // If end of input is encountered, the `skip` field of the invocation
+      // result is used to identify results that should be skipped.
+      void FetchOutputs(
+          const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+          const std::vector<std::shared_ptr<InvocationResult>>& results)
+          LOCKS_EXCLUDED(mu_) {
+        StartWork(ctx.get());
+        auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+        bool end_of_input = false;
+        for (auto& result : results) {
+          if (!end_of_input) {
+            result->status = current_elements_[cycle_index]->GetNext(
+                ctx.get(), &result->return_values, &end_of_input);
+          }
+          if (end_of_input) {
+            result->skip = true;
+          }
+          result->notification.Notify();
+          if (!result->status.ok()) {
+            break;
+          }
+        }
+
+        // Release the ownership of the cycle element iterator, closing the
+        // iterator if end of input was encountered.
+        {
+          if (end_of_input) {
+            current_elements_[cycle_index].reset();
+          }
+          mutex_lock l(mu_);
+          element_in_use_[cycle_index] = false;
+          num_calls_--;
+          if (end_of_input) {
+            args_list_[cycle_index].clear();
+            num_open_--;
+          }
+        }
+        cond_var_.notify_all();
+      }
+
+      int64 MaxInvocationResults() {
+        return dataset()->cycle_length_ * dataset()->block_length_;
+      }
+
+      // Method responsible for 1) creating iterators out of input elements, 2)
+      // determining the order in which elements are fetched from the iterators,
+      // and 3) scheduling the fetching of the elements to a threadpool.
+      //
+      // This method runs in the `runner_thread` background thread.
+      void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+        StartWork(ctx.get());
+        auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
+        while (true) {
+          {
+            mutex_lock l(mu_);
+            // Wait until this thread is cancelled, the end of input has been
+            // reached, or the cycle element at the `cycle_index_` position is
+            // not in use and there is space in the `invocation_results_` queue.
+            while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+                   (element_in_use_[cycle_index_] ||
+                    num_calls_ >= dataset()->num_parallel_calls_ ||
+                    invocation_results_.size() >= MaxInvocationResults())) {
+              StopWork(ctx.get());
+              cond_var_.wait(l);
+              StartWork(ctx.get());
+            }
+
+            if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+              return;
+            }
+
+            while (!element_in_use_[cycle_index_] &&
+                   (!end_of_input_ || num_open_ > 0) &&
+                   num_calls_ < dataset()->num_parallel_calls_ &&
+                   invocation_results_.size() < MaxInvocationResults()) {
+              if (!current_elements_[cycle_index_]) {
+                // Try to create a new iterator from the next input element.
+                Status status = input_impl_->GetNext(
+                    ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+                if (!status.ok()) {
+                  invocation_results_.emplace_back(new InvocationResult());
+                  std::shared_ptr<InvocationResult>& result =
+                      invocation_results_.back();
+                  result->status.Update(status);
+                  result->notification.Notify();
+                  break;
+                }
+                if (!end_of_input_) {
+                  Status status = MakeIteratorFromInputElement(
+                      ctx.get(), args_list_[cycle_index_], cycle_index_,
+                      dataset()->captured_func_.get(), prefix(),
+                      &current_elements_[cycle_index_]);
+                  if (!status.ok()) {
+                    invocation_results_.emplace_back(new InvocationResult());
+                    std::shared_ptr<InvocationResult>& result =
+                        invocation_results_.back();
+                    result->status.Update(status);
+                    result->notification.Notify();
+                    break;
+                  }
+                  ++num_open_;
+                }
+              }
+              if (current_elements_[cycle_index_]) {
+                // Pre-allocate invocation results for outputs to be fetched
+                // and then fetch the outputs asynchronously.
+                std::vector<std::shared_ptr<InvocationResult>> results;
+                results.reserve(dataset()->block_length_);
+                for (int i = 0; i < dataset()->block_length_; ++i) {
+                  invocation_results_.emplace_back(new InvocationResult());
+                  results.push_back(invocation_results_.back());
+                }
+                num_calls_++;
+                element_in_use_[cycle_index_] = true;
+                thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+                                                 ctx, cycle_index_,
+                                                 std::move(results)));
+              }
+              cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+            }
+          }
+          cond_var_.notify_all();
+        }
+      }
+
+      Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+                               const Status& status)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(
+            CodeKey(index), static_cast<int64>(status.code())));
+        if (!status.ok()) {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+                                                 status.error_message()));
+        }
+        return Status::OK();
+      }
+
+      Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+                              Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        int64 code_int;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+        error::Code code = static_cast<error::Code>(code_int);
+
+        if (code != error::Code::OK) {
+          string error_message;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(ErrorMessageKey(index), &error_message));
+          *status = Status(code, error_message);
+        } else {
+          *status = Status::OK();
+        }
+        return Status::OK();
+      }
+
+      string CodeKey(size_t index) {
+        return full_name(
+            strings::StrCat("invocation_results[", index, "].code"));
+      }
+
+      string ErrorMessageKey(size_t index) {
+        return full_name(
+            strings::StrCat("invocation_results[", index, "].error_message"));
+      }
+
+      Status WriteCurrentElements(IteratorStateWriter* writer)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        for (int idx = 0; idx < current_elements_.size(); idx++) {
+          if (current_elements_[idx]) {
+            TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+            TF_RETURN_IF_ERROR(writer->WriteScalar(
+                full_name(strings::StrCat("args_size[", idx, "]")),
+                args_list_[idx].size()));
+            for (int i = 0; i < args_list_[idx].size(); i++) {
+              TF_RETURN_IF_ERROR(writer->WriteTensor(
+                  full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+                  args_list_[idx][i]));
+            }
+          }
+        }
+        return Status::OK();
+      }
+
+      Status ReadCurrentElements(IteratorContext* ctx,
+                                 IteratorStateReader* reader)
+          EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        for (int idx = 0; idx < current_elements_.size(); idx++) {
+          if (reader->Contains(
+                  full_name(strings::StrCat("args_size[", idx, "]")))) {
+            int64 args_size;
+            TF_RETURN_IF_ERROR(reader->ReadScalar(
+                full_name(strings::StrCat("args_size[", idx, "]")),
+                &args_size));
+            args_list_[idx].resize(args_size);
+            for (int i = 0; i < args_size; i++) {
+              TF_RETURN_IF_ERROR(reader->ReadTensor(
+                  full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+                  &args_list_[idx][i]));
+            }
+            TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+                ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+                prefix(), &current_elements_[idx]));
+            TF_RETURN_IF_ERROR(
+                RestoreInput(ctx, reader, current_elements_[idx]));
+          } else {
+            current_elements_[idx].reset();
+          }
+        }
+        return Status::OK();
+      }
+
+      // Used for coordination between the main thread, the runner thread, and
+      // the worker threads.
+      mutex mu_;
+
+      // Used for coordination between the main thread, the runner thread, and
+      // the worker threads. In particular, the runner thread should only
+      // schedule new calls when the number of in-flight calls is less than the
+      // user specified level of parallelism, there are slots available in the
+      // `invocation_results_` buffer, the current cycle element is not in use,
+      // and there are elements left to be fetched.
+      condition_variable cond_var_;
+
+      // Iterator for input elements.
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+      // Identifies current cycle element.
+      int64 cycle_index_ = 0;
+
+      // Arguments for creating an iterator for cycle elements.
+      std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+      // Iterators for the current cycle elements. Concurrent access is
+      // protected by `element_in_use_`.
+      std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+      // Identifies cycle elements that are in use by worker threads.
+      std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+      // Buffer for storing the invocation results.
+      std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+          GUARDED_BY(mu_);
+
+      // Identifies whether end of input has been reached.
+      bool end_of_input_ GUARDED_BY(mu_) = false;
+
+      // Identifies the number of open iterators.
+      int64 num_open_ GUARDED_BY(mu_) = 0;
+
+      // Identifies the number of outstanding calls.
+      int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+      std::unique_ptr<thread::ThreadPool> thread_pool_;
+      std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+      // Identifies whether background activity should be cancelled.
+      bool cancelled_ GUARDED_BY(mu_) = false;
+    };
+
+    const DatasetBase* const input_;
+    const NameAttrList interleave_func_;
+    const std::unique_ptr<CapturedFunction> captured_func_;
+    const int64 cycle_length_;
+    const int64 block_length_;
+    const int64 num_parallel_calls_;
+    const DataTypeVector output_types_;
+    const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+  NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+                        ParallelInterleaveDatasetV2Op);
+
+}  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index bff5481..0795987 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -24,7 +24,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,11 +33,12 @@
 class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit ParallelMapDatasetOp(OpKernelConstruction* ctx)
-      : UnaryDatasetOpKernel(ctx),
-        graph_def_version_(ctx->graph_def_version()) {
+      : UnaryDatasetOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+                                     &use_inter_op_parallelism_));
   }
 
  protected:
@@ -60,10 +61,12 @@
 
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(ctx, CapturedFunction::Create(
-                            func_, std::move(other_arguments), &captured_func));
+                            func_, std::move(other_arguments),
+                            use_inter_op_parallelism_, &captured_func));
 
     *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
-                          output_shapes_, std::move(captured_func));
+                          output_shapes_, use_inter_op_parallelism_,
+                          std::move(captured_func));
   }
 
  private:
@@ -73,6 +76,7 @@
             const NameAttrList& func, int32 num_parallel_calls,
             const DataTypeVector& output_types,
             const std::vector<PartialTensorShape>& output_shapes,
+            bool use_inter_op_parallelism,
             std::unique_ptr<CapturedFunction> captured_func)
         : DatasetBase(DatasetContext(ctx)),
           input_(input),
@@ -80,6 +84,7 @@
           num_parallel_calls_(num_parallel_calls),
           output_types_(output_types),
           output_shapes_(output_shapes),
+          use_inter_op_parallelism_(use_inter_op_parallelism),
           captured_func_(std::move(captured_func)) {
       input_->Ref();
     }
@@ -92,16 +97,26 @@
         return captured_func_->Instantiate(ctx);
       };
 
-      auto map_func = [this](IteratorContext* ctx,
+      const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+      ParallelMapIteratorFunction map_func =
+          [this, new_prefix](IteratorContext* ctx,
                              std::vector<Tensor> input_element,
                              std::vector<Tensor>* result, StatusCallback done) {
-        captured_func_->RunAsync(ctx, std::move(input_element), result,
-                                 std::move(done));
-      };
+            captured_func_->RunAsync(ctx, std::move(input_element), result,
+                                     std::move(done), new_prefix);
+          };
+      if (!use_inter_op_parallelism_) {
+        map_func = [map_func](
+                       IteratorContext* ctx, std::vector<Tensor> input_element,
+                       std::vector<Tensor>* result, StatusCallback done) {
+          (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+                                     result, std::move(done)));
+        };
+      }
 
-      return NewParallelMapIterator(
-          {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
-          std::move(init_func), std::move(map_func), num_parallel_calls_);
+      return NewParallelMapIterator({this, new_prefix}, input_,
+                                    std::move(init_func), std::move(map_func),
+                                    num_parallel_calls_);
     }
 
     const DataTypeVector& output_dtypes() const override {
@@ -167,12 +182,13 @@
     const int32 num_parallel_calls_;
     const DataTypeVector output_types_;
     const std::vector<PartialTensorShape> output_shapes_;
+    const bool use_inter_op_parallelism_;
     const std::unique_ptr<CapturedFunction> captured_func_;
   };
 
-  const int graph_def_version_;
   DataTypeVector output_types_;
   std::vector<PartialTensorShape> output_shapes_;
+  bool use_inter_op_parallelism_;
   NameAttrList func_;
 };
 
@@ -180,5 +196,5 @@
                         ParallelMapDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 61f8139..0b6e587 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -19,7 +19,10 @@
 #include <utility>
 #include <vector>
 
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
 namespace tensorflow {
+namespace data {
 namespace {
 
 class ParallelMapIterator : public DatasetBaseIterator {
@@ -52,6 +55,7 @@
   }
 
   Status Initialize(IteratorContext* ctx) override {
+    SetMetadata(ctx, "parallelism", num_parallel_calls_);
     TF_RETURN_IF_ERROR(
         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
     if (init_func_) {
@@ -67,13 +71,17 @@
       mutex_lock l(mu_);
       EnsureRunnerThreadStarted(ctx);
       while (invocation_results_.empty()) {
+        StopWork(ctx);
         cond_var_.wait(l);
+        StartWork(ctx);
       }
       std::swap(result, invocation_results_.front());
       invocation_results_.pop_front();
     }
     cond_var_.notify_all();
+    StopWork(ctx);
     result->notification.WaitForNotification();
+    StartWork(ctx);
     return ProcessResult(result, out_tensors, end_of_sequence);
   }
 
@@ -86,9 +94,8 @@
     }
     CHECK_EQ(num_calls_, 0);
     TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
-    TF_RETURN_IF_ERROR(
-        writer->WriteScalar(full_name("invocation_results.size"),
-                            invocation_results_.size()));
+    TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+                                           invocation_results_.size()));
     for (size_t i = 0; i < invocation_results_.size(); i++) {
       std::shared_ptr<InvocationResult> result = invocation_results_[i];
       TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
@@ -225,6 +232,8 @@
   }
 
   void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+    StartWork(ctx.get());
+    auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
     std::vector<std::shared_ptr<InvocationResult>> new_calls;
     new_calls.reserve(num_parallel_calls_);
     while (true) {
@@ -233,7 +242,9 @@
         while (!cancelled_ &&
                (num_calls_ >= num_parallel_calls_ ||
                 invocation_results_.size() >= MaxInvocationResults())) {
+          StopWork(ctx.get());
           cond_var_.wait(l);
+          StartWork(ctx.get());
         }
         if (cancelled_) {
           return;
@@ -333,4 +344,5 @@
                               std::move(map_func), num_parallel_calls));
 }
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 7e6cc58..dc26c5c 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/dataset.h"
 
 namespace tensorflow {
+namespace data {
 
 // A function that transforms elements of one dataset into another
 // asynchronously. The arguments are:
@@ -47,6 +48,7 @@
     const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
     int32 num_parallel_calls);
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 9057800..0cf5db0 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -20,7 +20,7 @@
 #include "tensorflow/core/util/example_proto_fast_parsing.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -368,5 +368,5 @@
                         ParseExampleDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index b3272f6..533d0bd 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/kernels/data/prefetch_autotuner.h"
 
 namespace tensorflow {
+namespace data {
 
 PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
     : buffer_limit_(initial_buffer_size) {
@@ -43,4 +44,5 @@
   }
 }
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h
index fa8a184..8693205 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.h
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.h
@@ -19,6 +19,7 @@
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
+namespace data {
 
 // PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator.
 //
@@ -66,6 +67,7 @@
   Mode mode_ = Mode::kDisabled;
 };
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
index 29a8cc5..cfc324f 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 TEST(PrefetchAutotuner, Disabled) {
@@ -79,4 +80,5 @@
 }
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 50efbcb..52c421c 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -12,15 +12,19 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <deque>
-
 #include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
 
+#include <deque>
+
 #include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 
 namespace tensorflow {
+namespace data {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following op.
@@ -70,7 +74,11 @@
    public:
     explicit Iterator(const Params& params)
         : DatasetIterator<Dataset>(params),
-          auto_tuner_(params.dataset->buffer_size_) {}
+          auto_tuner_(params.dataset->buffer_size_) {
+      std::vector<string> components =
+          str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+      prefix_end_ = components.back();
+    }
 
     ~Iterator() override {
       // Signal the prefetch thread to terminate it. We will then
@@ -97,13 +105,16 @@
                            bool* end_of_sequence) override {
       {
         mutex_lock l(mu_);
+        auto stats_aggregator = ctx->stats_aggregator();
         TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
         // Wait until the next element in the buffer has been
         // produced, or we are shutting down.
         while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
                auto_tuner_.buffer_limit() != 0) {
           auto_tuner_.RecordEmpty();
+          StopWork(ctx);
           cond_var_.wait(l);
+          StartWork(ctx);
         }
 
         if (cancelled_) {
@@ -112,7 +123,7 @@
         }
 
         if (!buffer_.empty()) {
-          return Consume(out_tensors, end_of_sequence);
+          return Consume(out_tensors, end_of_sequence, stats_aggregator);
         }
 
         if (prefetch_thread_finished_) {
@@ -200,14 +211,22 @@
       std::vector<Tensor> value;
     };
 
-    Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+    Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence,
+                   const std::shared_ptr<StatsAggregator>& stats_aggregator)
         EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      if (stats_aggregator) {
+        stats_aggregator->AddToHistogram(
+            strings::StrCat(prefix_end_, "::buffer_utilization"),
+            {static_cast<float>(buffer_.size()) /
+             static_cast<float>(auto_tuner_.buffer_limit())});
+      }
       // A new element is available. Forward the status from computing it, and
       // (if we successfully got an element) the output values.
       Status s = buffer_.front().status;
       if (s.ok()) {
         *out_tensors = std::move(buffer_.front().value);
       }
+      auto_tuner_.RecordConsumption(buffer_.size());
       buffer_.pop_front();
       *end_of_sequence = false;
 
@@ -223,10 +242,10 @@
     Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
         EXCLUSIVE_LOCKS_REQUIRED(mu_) {
       if (!prefetch_thread_) {
-        prefetch_thread_.reset(
-            ctx->env()->StartThread({}, "prefetch_thread",
-                                    std::bind(&Iterator::PrefetchThread, this,
-                                              new IteratorContext(*ctx))));
+        std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+        prefetch_thread_.reset(ctx->env()->StartThread(
+            {}, "prefetch_thread",
+            [this, new_ctx]() { PrefetchThread(new_ctx); }));
       }
       return Status::OK();
     }
@@ -235,8 +254,9 @@
     // buffer.
     //
     // It owns the iterator context passed to it.
-    void PrefetchThread(IteratorContext* ctx) {
-      std::unique_ptr<IteratorContext> cleanup(ctx);
+    void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+      StartWork(ctx.get());
+      auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
       while (true) {
         std::vector<Tensor> value;
 
@@ -244,7 +264,9 @@
         {
           mutex_lock l(mu_);
           while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+            StopWork(ctx.get());
             cond_var_.wait(l);
+            StartWork(ctx.get());
           }
 
           if (cancelled_) {
@@ -261,8 +283,8 @@
         mutex_lock parent_l(parent_mu_);
         bool end_of_sequence;
         BufferElement buffer_element;
-        buffer_element.status =
-            input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+        buffer_element.status = input_impl_->GetNext(
+            ctx.get(), &buffer_element.value, &end_of_sequence);
         if (buffer_element.status.ok() && end_of_sequence) {
           mutex_lock l(mu_);
           prefetch_thread_finished_ = true;
@@ -324,6 +346,7 @@
     mutex parent_mu_ ACQUIRED_BEFORE(mu_);
     std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
     condition_variable cond_var_;
+    string prefix_end_;
     PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
     std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
     std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
@@ -346,6 +369,7 @@
   *output = new Dataset(ctx, input, buffer_size);
 }
 
+namespace {
 REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
                         PrefetchDatasetOp);
 REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
@@ -354,4 +378,7 @@
                             .HostMemory("input_dataset")
                             .HostMemory("handle"),
                         PrefetchDatasetOp);
+}  // namespace
+
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
index c40c4b0..588fb25 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.h
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -20,6 +20,7 @@
 #include "tensorflow/core/kernels/data/prefetch_autotuner.h"
 
 namespace tensorflow {
+namespace data {
 
 class PrefetchDatasetOp : public UnaryDatasetOpKernel {
  public:
@@ -34,6 +35,7 @@
   class Dataset;
 };
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 7817170..044a791 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/lib/random/random_distributions.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -151,5 +151,5 @@
                         RandomDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index aa38775..89fbaae 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -142,5 +142,5 @@
                         RangeDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 086b552..c474cb4 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -23,7 +23,7 @@
 #include "tensorflow/core/lib/io/zlib_inputstream.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -691,5 +691,5 @@
                         TFRecordDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 299949b..94e9663 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -250,5 +250,5 @@
                         RepeatDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index fccad93..6e515d6 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -23,7 +23,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -279,5 +279,5 @@
 REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 93a4376..66466d6 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -25,7 +25,7 @@
 #include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
@@ -620,5 +620,5 @@
                         ShuffleAndRepeatDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
new file mode 100644
index 0000000..5b084a1
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -0,0 +1,380 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+  explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+      : params_(params) {}
+
+  ~SingleThreadedExecutorImpl() override {
+    for (const KernelState& kernel_state : kernels_) {
+      params_.delete_kernel(kernel_state.kernel);
+    }
+  }
+
+  Status Initialize(const Graph& graph) {
+    // Topologicially sort `graph` to get a sequence of OpKernels.
+    std::vector<Node*> ordered_nodes;
+    ordered_nodes.reserve(graph.num_nodes());
+    GetReversePostOrder(graph, &ordered_nodes);
+
+    if (ordered_nodes.size() != graph.num_nodes()) {
+      return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+                                     " but reverse post-order had ",
+                                     ordered_nodes.size());
+    }
+
+    kernels_.resize(ordered_nodes.size());
+
+    std::unordered_map<Node*, size_t> node_to_index_map;
+
+    // Create the kernel and input-related structures for each node in `graph`.
+    for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+      Node* n = ordered_nodes[i];
+      node_to_index_map[n] = i;
+
+      for (DataType dt : n->output_types()) {
+        if (IsRefType(dt)) {
+          return errors::Unimplemented(
+              "Single-threaded executor does not support reference-typed "
+              "edges.");
+        }
+      }
+
+      if (n->IsControlFlow()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support control flow.");
+      }
+      if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support partitioned graphs.");
+      }
+      if (n->IsCollective()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support collective ops.");
+      }
+
+      KernelState& kernel_state = kernels_[i];
+      TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+      kernel_state.num_inputs = n->num_inputs();
+      kernel_state.num_outputs = n->num_outputs();
+
+      if (i == 0) {
+        kernel_state.input_start_index = 0;
+      } else {
+        const KernelState& previous_kernel_state = kernels_[i - 1];
+        kernel_state.input_start_index =
+            previous_kernel_state.input_start_index +
+            previous_kernel_state.num_inputs;
+      }
+    }
+
+    // Build the mapping from each node output to the input slot for the
+    // corresponding destination node.
+    for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+      Node* n = ordered_nodes[i];
+      KernelState& kernel_state = kernels_[i];
+      kernel_state.output_locations.resize(kernel_state.num_outputs);
+      for (const Edge* e : n->out_edges()) {
+        if (!e->IsControlEdge()) {
+          kernel_state.output_locations[e->src_output()].push_back(
+              kernels_[node_to_index_map[e->dst()]].input_start_index +
+              e->dst_input());
+        }
+      }
+
+      // Compute allocator attributes for each node output, and corresponding
+      // node input.
+      kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
+      AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();
+
+      OpKernel* op_kernel = kernel_state.kernel;
+      for (int out = 0; out < n->num_outputs(); out++) {
+        DCHECK_LT(out, op_kernel->output_memory_types().size());
+        bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
+        if (on_host) {
+          AllocatorAttributes h;
+          h.set_on_host(on_host);
+          attrs[out].Merge(h);
+        }
+      }
+    }
+
+    if (!kernels_.empty()) {
+      const KernelState& last_kernel_state = kernels_.back();
+      total_num_inputs_ =
+          last_kernel_state.input_start_index + last_kernel_state.num_inputs;
+      input_alloc_attrs_.resize(total_num_inputs_);
+      for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+        for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
+          for (size_t output_location : kernels_[i].output_locations[j]) {
+            input_alloc_attrs_[output_location] =
+                kernels_[i].output_alloc_attrs[j];
+          }
+        }
+      }
+    } else {
+      total_num_inputs_ = 0;
+    }
+    return Status::OK();
+  }
+
+  // TODO(mrry): Consider specializing the implementation of Executor::Run()
+  // instead, to avoid unnecessary atomic operations in the callback when
+  // running synchronously.
+  void RunAsync(const Args& args, DoneCallback done) override {
+    // The inputs to each kernel are stored contiguously in `inputs`.
+    //
+    // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
+    // determine the range of elements in this vector that correspond to
+    // the inputs of `kernels_[i]`.
+    //
+    // This vector has the following layout:
+    //
+    // * Kernel 0, input 0.
+    // * Kernel 0, input 1.
+    // * ...
+    // * Kernel 0, input `kernels_[0].num_inputs - 1`.
+    // * Kernel 1, input 0.
+    // * ...
+    // * Kernel 1, input `kernels_[1].num_inputs - 1`.
+    // * ...
+    // * Kernel `kernels_.size() - 1`, input 0.
+    // * ...
+    // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
+    //
+    // Note that kernels with zero inputs do not correspond to any elements in
+    // this vector.
+    //
+    // We use `ManualConstructor<Tensor>` to avoid the overhead of
+    // default-constructing an invalid `Tensor` for each slot at the beginning
+    // of execution:
+    // * Elements are initialized when the outputs of a kernel execution are
+    //   propagated to the inputs of kernels that depend on them.
+    // * The elements corresponding to the inputs for kernel `i` are destroyed
+    //   after kernel `i` executes.
+    // * In an error case (see below), we use the connectivity information in
+    //   `KernelState::output_locations` to determine which locations have been
+    //   initialized, and manually destroy them.
+    std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_);
+
+    // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
+    // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
+    TensorValueVec node_inputs;
+    DeviceContextVec input_device_contexts;
+    AllocatorAttributeVec input_alloc_attrs;
+
+    // Prepare the parameters that will be the same for all kernels.
+    OpKernelContext::Params params;
+    params.step_id = args.step_id;
+    Device* device = params_.device;
+    params.device = device;
+    params.log_memory = false;              // TODO(mrry): Too severe?
+    params.record_tensor_accesses = false;  // TODO(mrry): Too severe?
+    params.rendezvous = args.rendezvous;
+    params.session_state = args.session_state;
+    params.tensor_store = args.tensor_store;
+    params.cancellation_manager = args.cancellation_manager;
+    // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+    // allocations that it performs. Consider specializing its handling in the
+    // executor.
+    params.call_frame = args.call_frame;
+    params.function_library = params_.function_library;
+    params.resource_manager = device->resource_manager();
+    params.step_container = args.step_container;
+    params.slice_reader_cache = nullptr;  // TODO(mrry): Too severe?
+    params.inputs = &node_inputs;
+    params.input_device_contexts = &input_device_contexts;
+    params.input_alloc_attrs = &input_alloc_attrs;
+
+    Args::Runner runner_copy = args.runner;
+    params.runner = &runner_copy;
+    params.stats_collector = args.stats_collector;
+
+    // NOTE(mrry): We are assuming that the graph is loopless and condless.
+    params.frame_iter = FrameAndIter(0, 0);
+    params.is_input_dead = false;
+
+    // TODO(mrry): Add non-default device context inference.
+    params.op_device_context = nullptr;
+    // TODO(mrry): Consider implementing forwarding.
+    params.forward_from_array = nullptr;
+
+    // Execute the kernels one-at-a-time in topological order.
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      const KernelState& kernel_state = kernels_[i];
+
+      // Prepare the per-kernel parameters.
+      const size_t input_start_index = kernel_state.input_start_index;
+      const size_t num_inputs = kernel_state.num_inputs;
+      const size_t num_outputs = kernel_state.num_outputs;
+
+      node_inputs.clear();
+      node_inputs.resize(num_inputs);
+      input_alloc_attrs.clear();
+      input_alloc_attrs.resize(num_inputs);
+      for (size_t j = 0; j < num_inputs; ++j) {
+        auto t = inputs[input_start_index + j].get();
+        node_inputs[j].tensor = t;
+        input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+      }
+      params.op_kernel = kernel_state.kernel;
+      input_device_contexts.clear();
+      input_device_contexts.resize(num_inputs);
+      params.output_attr_array = kernel_state.output_alloc_attrs.data();
+      OpKernelContext ctx(&params, num_outputs);
+
+      // Actually execute the kernel.
+      device->Compute(kernel_state.kernel, &ctx);
+
+      if (!ctx.status().ok()) {
+        // On failure, we must manually free all intermediate tensors. We have
+        // already freed all the inputs for kernels up to (but not including)
+        // the `i`th kernel. We scan through the previously executed kernels and
+        // destroy any tensors that were destined to be the input for a kernel
+        // that has not yet executed.
+        for (size_t j = 0; j < i; ++j) {
+          const KernelState& executed_kernel_state = kernels_[j];
+          for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+            for (size_t output_location :
+                 executed_kernel_state.output_locations[k]) {
+              if (output_location >= input_start_index) {
+                // Only destroy an output location if it is an input to an
+                // operation that has not yet executed.
+                inputs[output_location].Destroy();
+              }
+            }
+          }
+        }
+        done(ctx.status());
+        return;
+      }
+
+      // Free the inputs to the current kernel.
+      for (size_t j = 0; j < num_inputs; ++j) {
+        inputs[input_start_index + j].Destroy();
+      }
+
+      // Forward the outputs of the kernel to the inputs of subsequent kernels.
+      for (size_t j = 0; j < num_outputs; ++j) {
+        TensorValue val = ctx.release_output(j);
+        // TODO(mrry): Consider flattening the `output_locations` vector
+        // to improve the cache-friendliness of this loop.
+        for (size_t output_location : kernel_state.output_locations[j]) {
+          // TODO(mrry): Validate that the types match the expected values or
+          // ensure that the necessary validation has already happened.
+          inputs[output_location].Init(*val.tensor);
+        }
+        delete val.tensor;
+      }
+    }
+    done(Status::OK());
+  }
+
+ private:
+  const LocalExecutorParams params_;
+
+  // All following members are read-only after Initialize().
+
+  // The sum of the number of inputs for each node in the graph. This determines
+  // the length of the flat `inputs` vector. See comment at the beginning of
+  // `RunAsync()` for details.
+  size_t total_num_inputs_;
+
+  // Represents cached graph structure state for each kernel.
+  struct KernelState {
+    // The kernel object. Not owned.
+    //
+    // This pointer is managed by `params_.create_kernel()` and
+    // `params_.delete_kernel()`.
+    OpKernel* kernel;
+
+    // These fields determine the range of elements in `inputs` that corresponds
+    // to the inputs of `kernel`.
+    size_t input_start_index;
+    size_t num_inputs;
+
+    size_t num_outputs;
+
+    // For the `j`th output of `kernel`, `output_locations[j]` contains the
+    // locations in the flat `inputs` vector to which that output must be
+    // copied. See comment at the beginning of `RunAsync()` for details.
+    std::vector<std::vector<size_t>>
+        output_locations;  // Length = `num_outputs`.
+
+    // Memory space information for each output of `kernel`.
+    std::vector<AllocatorAttributes>
+        output_alloc_attrs;  // Length = `num_outputs`.
+  };
+  std::vector<KernelState> kernels_;
+
+  // Memory space information for each input. This information is stored in the
+  // same order as the flat `inputs` vector. See comment at the beginning of
+  // `RunAsync()` for details.
+  std::vector<AllocatorAttributes>
+      input_alloc_attrs_;  // Length = `total_num_inputs_`.
+};
+
+class SingleThreadedExecutorRegistrar {
+ public:
+  SingleThreadedExecutorRegistrar() {
+    ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory());
+  }
+
+ private:
+  class Factory : public ExecutorFactory {
+    Status NewExecutor(const LocalExecutorParams& params,
+                       std::unique_ptr<const Graph> graph,
+                       std::unique_ptr<Executor>* out_executor) override {
+      Executor* ret;
+      TF_RETURN_IF_ERROR(
+          NewSingleThreadedExecutor(params, std::move(graph), &ret));
+      out_executor->reset(ret);
+      return Status::OK();
+    }
+  };
+};
+static SingleThreadedExecutorRegistrar registrar;
+
+}  // namespace
+
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+                                 std::unique_ptr<const Graph> graph,
+                                 Executor** executor) {
+  std::unique_ptr<SingleThreadedExecutorImpl> impl(
+      new SingleThreadedExecutorImpl(params));
+  TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+  *executor = impl.release();
+  return Status::OK();
+}
+
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h
new file mode 100644
index 0000000..e934352
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/executor.h"
+
+namespace tensorflow {
+namespace data {
+
+// Creates a new `Executor` for executing `graph` synchronously on the caller
+// thread.
+//
+// NOTE(mrry): The returned executor is optimized to impose low overhead on
+// graphs that perform a small amount of work (e.g. <15us of work per graph on
+// present architectures). It eschews concurrency, because issuing work to
+// multiple threads can dominate the cost of executing small ops synchronously,
+// and because contention in the executor data structures can reduce throughput
+// (in terms of ops executed per unit time).
+//
+// However, the current implementation has the following limitations:
+//
+// 1. Reference-typed tensors are not supported and will not be supported in
+//    future.
+// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not
+//    currently supported. The current plan is to extend support to "functional"
+//    control flow after the TensorFlow APIs transition to building graphs in
+//    that form (e.g. `tf.cond_v2()`).
+// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported.
+//    The present implementation executes kernels one at a time in topological
+//    order, and cannot currently distinguish between disconnected subgraphs
+//    that are logically connected by subgraphs on a different device.
+// 4. Memory logging is not currently supported.
+// 5. Allocation forwarding is not currently supported.
+// 6. Non-default device contexts are not currently supported. In effect, this
+//    limits the executor to CPU devices.
+// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null
+//    are not currently supported.
+//
+// The single-threaded executor is primarily suitable for executing simple
+// TensorFlow functions, such as one might find in a `tf.data` pipeline.
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+                                 std::unique_ptr<const Graph> graph,
+                                 Executor** executor);
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
new file mode 100644
index 0000000..6244e28
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+  ExecutorTest()
+      : device_(DeviceFactory::NewDevice("CPU", {},
+                                         "/job:localhost/replica:0/task:0")) {}
+
+  ~ExecutorTest() override {
+    // There should always be exactly one Ref left on the Rendezvous
+    // when the test completes.
+    CHECK(rendez_->Unref());
+    delete exec_;
+    delete device_;
+  }
+
+  // Resets executor_ with a new executor based on a graph 'gdef'.
+  void Create(std::unique_ptr<const Graph> graph) {
+    const int version = graph->versions().producer();
+    LocalExecutorParams params;
+    params.device = device_;
+    params.create_kernel = [this, version](const NodeDef& ndef,
+                                           OpKernel** kernel) {
+      return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+    };
+    params.delete_kernel = [](OpKernel* kernel) {
+      DeleteNonCachedKernel(kernel);
+    };
+    delete exec_;
+    TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+    runner_ = [](std::function<void()> fn) { fn(); };
+    rendez_ = NewLocalRendezvous();
+  }
+
+  Status Run(Rendezvous* rendez) {
+    Executor::Args args;
+    args.rendezvous = rendez;
+    args.runner = runner_;
+    return exec_->Run(args);
+  }
+
+  Status Run(CallFrameInterface* call_frame) {
+    Executor::Args args;
+    args.call_frame = call_frame;
+    args.runner = runner_;
+    return exec_->Run(args);
+  }
+
+  Device* device_ = nullptr;
+  Executor* exec_ = nullptr;
+  Executor::Args::Runner runner_;
+  Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+  Tensor tensor(DT_FLOAT, TensorShape({}));
+  tensor.scalar<float>()() = val;
+  return tensor;
+}
+
+// A int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+  Tensor tensor(DT_INT32, TensorShape({}));
+  tensor.scalar<int32>()() = val;
+  return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+  Tensor tensor(DT_BOOL, TensorShape({}));
+  tensor.scalar<bool>()() = val;
+  return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+  Tensor tensor(DT_DOUBLE, TensorShape({}));
+  tensor.scalar<double>()() = val;
+  return tensor;
+}
+
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+  CHECK_EQ(tensor.dtype(), DT_FLOAT);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+                          const string& receiver, const string& name) {
+  Rendezvous::ParsedKey result;
+  TF_CHECK_OK(
+      Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+                                                 name, FrameAndIter(0, 0)),
+                           &result));
+  return result;
+}
+
+TEST_F(ExecutorTest, SimpleAdd) {
+  // c = a + b
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  auto in1 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  auto tmp = test::graph::Add(g.get(), in0, in1);
+  test::graph::Retval(g.get(), 0, tmp);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(2.0, V(retvals[0]));  // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+  // v0 <- a
+  // v1 = v0 + v0
+  // v2 = v1 + v1
+  // ... ...
+  // v10 = v9 + v9
+  //
+  // b <- v10
+  // All nodes are executed by one thread.
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  const int N = 10;
+  for (int i = 1; i <= N; ++i) {
+    v = test::graph::Add(g.get(), v, v);
+  }
+  // out <- v10
+  test::graph::Retval(g.get(), 0, v);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+  // a = 1.0
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(1024.0, V(retvals[0]));  // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+//     a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+//     a + ((a + a) + a)
+//     (a + a) + (a + a)
+//     ((a + a) + a) + a
+// are all possibly generated.
+void BuildTree(int N, Graph* g) {
+  CHECK_GT(N, 1);
+  // A single input node "in".
+  auto in = test::graph::Arg(g, 0, DT_FLOAT);
+  std::vector<Node*> nodes;
+  int i = 0;
+  // Duplicate "in" N times. Each copy is named as l0, l1, l2, ....
+  for (; i < N; ++i) {
+    nodes.push_back(test::graph::Identity(g, in, 0));
+  }
+  random::PhiloxRandom philox(0, 17);
+  random::SimplePhilox rnd(&philox);
+  while (nodes.size() > 1) {
+    // Randomly pick two from nodes and add them. The resulting node
+    // is named like n10, n11, .... and is put back into "nodes".
+    int x = rnd.Uniform(nodes.size());
+    auto in0 = nodes[x];
+    nodes[x] = nodes.back();
+    nodes.resize(nodes.size() - 1);
+    x = rnd.Uniform(nodes.size());
+    auto in1 = nodes[x];
+    // node = in0 + in1.
+    nodes[x] = test::graph::Add(g, in0, in1);
+  }
+  // The final output node "out".
+  test::graph::Retval(g, 0, nodes.back());
+  FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  BuildTree(4096, g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto zero = test::graph::Constant(g.get(), V(0.0));
+  auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+  auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+  auto two = test::graph::Constant(g.get(), V(2.0));
+  test::graph::Binary(g.get(), "Mul", check, two);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({}, {});
+  // Fails because CheckNumerics detects the Inf produced by Reciprocal(0.0).
+  EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+  BenchmarkUseRealTime();
+#endif  // PLATFORM_GOOGLE
+  Graph* g = new Graph(OpRegistry::Global());
+  random::PhiloxRandom philox(1729, 17);
+  random::SimplePhilox rand(&philox);
+  uint64 cur = 0;
+  uint32 r = 1 + rand.Rand32() % width;
+  std::vector<Node*> ready_nodes;
+  for (int i = 0; i < r; ++i) {
+    ready_nodes.push_back(test::graph::NoOp(g, {}));
+    ++cur;
+  }
+  for (int i = 0; i < depth; ++i) {
+    std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+    r = 1 + rand.Rand32() % (ready_nodes.size());
+    std::vector<Node*> control_inputs;
+    for (int j = 0; j < r; ++j) {
+      control_inputs.push_back(ready_nodes.back());
+      ready_nodes.pop_back();
+    }
+    Node* n = test::graph::NoOp(g, control_inputs);
+    ++cur;
+    r = 1 + rand.Rand32() % width;
+    for (int j = 0; j < r; ++j) {
+      ready_nodes.push_back(test::graph::NoOp(g, {n}));
+      ++cur;
+    }
+  }
+  FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+  SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+  SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif  // PLATFORM_GOOGLE
+  test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+                  "SINGLE_THREADED_EXECUTOR")
+      .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single threaded executor does not retain
+// a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+  Graph* g = new Graph(OpRegistry::Global());
+  // z = x + y: x and y are provided as benchmark inputs.  z is the
+  // output of the benchmark.  Conceptually, the caller is ALICE, the
+  // benchmark is BOB.
+  Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+  Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
+  Node* sum = test::graph::Add(g, x, y);
+  Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
+  FixupSourceAndSinkEdges(g);
+  Tensor val(DT_FLOAT, TensorShape({}));
+  val.scalar<float>()() = 3.14;
+  SetBenchmarkItemsProcessed(static_cast<int64>(iters));
+  test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+                  "SINGLE_THREADED_EXECUTOR")
+      .RunWithArgs({{x, val}, {y, val}}, {z}, iters);
+}
+BENCHMARK(BM_FeedInputFetchOutput);
+#endif
+
+}  // namespace
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index fe7ef38..b8c7fb1 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -187,5 +187,5 @@
 REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 14df3a6..1e73cfc 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -23,7 +23,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -293,5 +293,5 @@
                         SlideDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index e526578..85b1e50 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -21,7 +21,7 @@
 #include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -274,5 +274,5 @@
 #undef REGISTER_DATASET_KERNEL
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc
index ffabda1..783d1e6 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.cc
+++ b/tensorflow/core/kernels/data/sql/driver_manager.cc
@@ -16,7 +16,7 @@
 #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace sql {
 
 std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
@@ -30,5 +30,5 @@
 }
 
 }  // namespace sql
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h
index a34691b..c5428f3 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/sql/driver_manager.h
@@ -18,7 +18,7 @@
 #include "tensorflow/core/kernels/data/sql/query_connection.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace sql {
 
 // A factory class for creating `QueryConnection` instances.
@@ -35,7 +35,7 @@
 };
 
 }  // namespace sql
-
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h
index e9ffca2..2fd229a 100644
--- a/tensorflow/core/kernels/data/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/sql/query_connection.h
@@ -18,6 +18,7 @@
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tensorflow {
+namespace data {
 
 class IteratorContext;
 
@@ -63,7 +64,7 @@
 };
 
 }  // namespace sql
-
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index 7cd07bd..5108e83 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -19,7 +19,7 @@
 #include "tensorflow/core/lib/strings/stringprintf.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace sql {
 
 SqliteQueryConnection::SqliteQueryConnection() {}
@@ -115,5 +115,5 @@
 }
 
 }  // namespace sql
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
index 81b1953..175492c 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
@@ -22,7 +22,7 @@
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace sql {
 
 class SqliteQueryConnection : public QueryConnection {
@@ -50,7 +50,7 @@
 };
 
 }  // namespace sql
-
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 2aa153f..6bbe459 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -24,8 +24,9 @@
 #include "tensorflow/core/lib/strings/stringprintf.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
+
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following ops.
 
@@ -211,5 +212,5 @@
 REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 75af73d..f5314f7 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -19,6 +19,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
@@ -135,4 +136,5 @@
 REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
                         SetStatsAggregatorDatasetOp);
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index b133cfa..a7ded67 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -26,6 +26,7 @@
 #include "tensorflow/core/platform/macros.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 static mutex* get_counters_map_lock() {
@@ -145,4 +146,5 @@
                         StatsAggregatorSummaryOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 8957f5d..e9e42f0 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/core/lib/random/random.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 // This op defines a `Dataset` that passes through its input elements and
@@ -248,4 +249,5 @@
                         BytesProducedStatsDatasetOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index e5c237d..e5cdfdd 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -174,5 +174,5 @@
 REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 1192faf..e1cefd2 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -140,5 +140,5 @@
                         TensorDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index ccd5e60..2ed636a 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -24,7 +24,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a,
@@ -648,5 +648,5 @@
                         EnqueueInQueueDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index dc32cd2..7dc64b0 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -19,7 +19,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -168,5 +168,5 @@
                         TensorSliceDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 1a79f72..81c432b 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -18,7 +18,7 @@
 #include "tensorflow/core/util/batch_util.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -204,5 +204,5 @@
                         UnbatchDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 0ab6bea..2ad4711 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
+namespace data {
 namespace {
 
 class WindowDataset : public DatasetBase {
@@ -107,4 +108,5 @@
   return Status::OK();
 }
 
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 7bd31a0..84cb3c7 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -23,6 +23,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
+namespace data {
 
 // Creates a dataset representing an eagerly-collected window of elements.
 //
@@ -43,6 +44,7 @@
                         std::vector<PartialTensorShape> output_shapes,
                         DatasetBase** out_dataset);
 
+}  // namespace data
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 41bf9d4..3975086 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -19,7 +19,7 @@
 #include "tensorflow/core/kernels/data/window_dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -195,5 +195,5 @@
                         WindowDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 1c49874..3f76695 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -22,7 +22,7 @@
 #include "tensorflow/core/platform/file_system.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 class ToTFRecordOp : public AsyncOpKernel {
@@ -104,4 +104,5 @@
                         ToTFRecordOp);
 
 }  // namespace
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index e430657..61a2078 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -17,7 +17,7 @@
 #include "tensorflow/core/kernels/data/dataset.h"
 
 namespace tensorflow {
-
+namespace data {
 namespace {
 
 // See documentation in ../ops/dataset_ops.cc for a high-level
@@ -175,5 +175,5 @@
 REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp);
 
 }  // namespace
-
+}  // namespace data
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 33ed552..d705e82 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -255,7 +255,7 @@
     TensorShape shape({1});
     OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
     output_tensor->vec<int64>()(0) = nan_count;
-    PublishTensor(*output_tensor);
+    OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
   }
 };
 
@@ -380,7 +380,7 @@
     bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
                 positive_inf_count == 0;
     if (!mute) {
-      PublishTensor(*output_tensor);
+      OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
     }
   }
 
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a..750efca 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -108,8 +108,7 @@
     const int32 abs_height = abs(height);
 
     // there may be padding bytes when the width is not a multiple of 4 bytes
-    // 8 * channels == bits per pixel
-    const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+    const int row_size = (channels_ * width + 3) / 4 * 4;
 
     const int64 last_pixel_offset = static_cast<int64>(header_size) +
                                     (abs_height - 1) * row_size +
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847..6bfb5bd 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@
     OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
 
     for (int i = 0; i < record_defaults.size(); ++i) {
+      OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+                  errors::InvalidArgument(
+                      "Each record default should be at most rank 1"));
       OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
                   errors::InvalidArgument(
                       "There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index b01db91..fb2a4cc 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -247,8 +247,8 @@
             data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
 
         if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
-          T* merged_base = &merged_flat(0, 0);
-          const T* data_base = &data_flat(0, 0);
+          T* merged_base = merged_flat.data();
+          const T* data_base = data_flat.data();
           for (int i = 0; i < indices_vec.size(); i++) {
             int32 index = internal::SubtleMustCopy(indices_vec(i));
             OP_REQUIRES(
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index e13e548..8edf7d4 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -51,48 +51,55 @@
                      internal::traits<OutputBackward>::NumDimensions>,
         const TensorContractionOp<
             const array<
-                IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
+            const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+                const DSizes<typename internal::traits<OutputBackward>::Index,
+                             2>,
+                const TensorShufflingOp<
+                    const array<
+                        typename internal::traits<OutputBackward>::Index, 5>,
+                    const TensorReverseOp<const Eigen::array<bool, 5>,
+                                          const Kernel>>>>,
             const TensorReshapingOp<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
-                             3>,
-                const TensorReverseOp<const array<bool, 5>, const Kernel> >,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<OutputBackward>::Index,
-                             3>,
+                             2>,
                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const OutputBackward> > > >,
+                                          const OutputBackward>>>>,
     TensorReshapingOp<
         const DSizes<typename internal::traits<OutputBackward>::Index,
                      internal::traits<OutputBackward>::NumDimensions>,
         const TensorContractionOp<
             const array<
-                IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+                IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
             const TensorReshapingOp<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
-                             3>,
+                             2>,
                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                          const OutputBackward> >,
-            const TensorReshapingOp<
+                                          const OutputBackward>>,
+            const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
                 const DSizes<typename internal::traits<OutputBackward>::Index,
-                             3>,
-                const TensorReverseOp<const array<bool, 5>,
-                                      const Kernel> > > > >::type
+                             2>,
+                const TensorShufflingOp<
+                    const array<
+                        typename internal::traits<OutputBackward>::Index, 5>,
+                    const TensorReverseOp<const Eigen::array<bool, 5>,
+                                          const Kernel>>>>>>>::type
 CuboidConvolutionBackwardInput(
     const Kernel& kernel, const OutputBackward& output_backward,
     typename internal::traits<OutputBackward>::Index inputPlanes,
     typename internal::traits<OutputBackward>::Index inputRows,
     typename internal::traits<OutputBackward>::Index inputCols,
-    const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
-    const DenseIndex strideCols = 1) {
+    const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1,
+    const DenseIndex col_stride = 1) {
   typedef typename internal::traits<OutputBackward>::Index TensorIndex;
   const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
                                internal::traits<Kernel>::NumDimensions,
-                               internal::traits<Kernel>::Layout, TensorIndex> >
+                               internal::traits<Kernel>::Layout, TensorIndex>>
       kern(kernel);
   const TensorRef<
       const Tensor<typename internal::traits<OutputBackward>::Scalar,
                    internal::traits<OutputBackward>::NumDimensions,
-                   internal::traits<OutputBackward>::Layout, TensorIndex> >
+                   internal::traits<OutputBackward>::Layout, TensorIndex>>
       out(output_backward);
 
   EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -125,58 +132,45 @@
   const TensorIndex outputCols =
       isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
 
-  TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
-  const TensorIndex size_z =
-      Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
-  const TensorIndex size_y =
-      Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
-  const TensorIndex size_x =
-      Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
+  // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+  // effective kernel planes/rows/cols are always the same as the kernel itself
+  // (see eigen_spatial_convolutions for details).
+  const TensorIndex kernelPlanesEff = kernelPlanes;
+  const TensorIndex kernelRowsEff = kernelRows;
+  const TensorIndex kernelColsEff = kernelCols;
 
-  // Infer padding type.
-  if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
-    // SAME padding.
-    const TensorIndex dz = numext::maxi<TensorIndex>(
-        0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
-    const TensorIndex dy = numext::maxi<TensorIndex>(
-        0, (size_y - 1) * strideRows + kernelRows - inputRows);
-    const TensorIndex dx = numext::maxi<TensorIndex>(
-        0, (size_x - 1) * strideCols + kernelCols - inputCols);
+  // Computing the forward padding.
+  const TensorIndex forward_pad_top_z = numext::maxi<Index>(
+      0,
+      ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2);
+  const TensorIndex forward_pad_top = numext::maxi<Index>(
+      0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
+  const TensorIndex forward_pad_left = numext::maxi<Index>(
+      0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
 
-    forward_pad_z = dz / 2;
-    forward_pad_y = dy / 2;
-    forward_pad_x = dx / 2;
-  } else {
-    // VALID padding.
-    forward_pad_z = 0;
-    forward_pad_y = 0;
-    forward_pad_x = 0;
-  }
-  const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
-  const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
-  const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
+  const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z;
+  const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+  const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
 
-  const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
-                                      (outputPlanes - 1) * stridePlanes - 1 -
-                                      padding_ztop;
-  const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
-                                     (outputRows - 1) * strideRows - 1 -
-                                     padding_top;
-  const TensorIndex padding_right = inputCols + kernelCols - 1 -
-                                    (outputCols - 1) * strideCols - 1 -
-                                    padding_left;
+  const TensorIndex padding_bottom_z = inputPlanes -
+                                       (outputPlanes - 1) * plane_stride - 2 -
+                                       padding_top_z + kernelPlanesEff;
+  const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
+                                     2 - padding_top + kernelRowsEff;
+  const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
+                                    2 - padding_left + kernelColsEff;
 
-  eigen_assert(padding_ztop >= 0);
-  eigen_assert(padding_zbottom >= 0);
+  eigen_assert(padding_top_z >= 0);
   eigen_assert(padding_top >= 0);
   eigen_assert(padding_left >= 0);
+  eigen_assert(padding_bottom_z >= 0);
   eigen_assert(padding_bottom >= 0);
   eigen_assert(padding_right >= 0);
 
-  // The kernel has dimensions filters X channels X patch_planes X patch_rows X
-  // patch_cols.
+  // The kernel has dimensions:
+  //   filters x channels x patch_planes x patch_rows x patch_cols.
   // We need to reverse the kernel along the spatial dimensions.
-  array<bool, 5> kernel_reverse;
+  Eigen::array<bool, 5> kernel_reverse;
   if (isColMajor) {
     kernel_reverse[0] = false;
     kernel_reverse[1] = false;
@@ -191,15 +185,35 @@
     kernel_reverse[4] = false;
   }
 
-  DSizes<TensorIndex, 3> kernel_dims;
+  // Reorder the dimensions to:
+  //   filters x patch_planes x patch_rows x patch_cols x channels
+  array<TensorIndex, 5> kernel_shuffle;
   if (isColMajor) {
-    kernel_dims[0] = kernelFilters;
-    kernel_dims[1] = kernelChannels;
-    kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
+    //  From: filters x channels x planes x rows x cols
+    //  To:   filters x planes x rows x cols x channels
+    kernel_shuffle[0] = 0;
+    kernel_shuffle[1] = 2;
+    kernel_shuffle[2] = 3;
+    kernel_shuffle[3] = 4;
+    kernel_shuffle[4] = 1;
   } else {
-    kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
+    //  From: cols x rows x planes x channels x filters
+    //  To:   channels x cols x rows x planes x filters
+    kernel_shuffle[0] = 3;
+    kernel_shuffle[1] = 0;
+    kernel_shuffle[2] = 1;
+    kernel_shuffle[3] = 2;
+    kernel_shuffle[4] = 4;
+  }
+
+  // Collapse the dims.
+  DSizes<TensorIndex, 2> kernel_dims;
+  if (isColMajor) {
+    kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
     kernel_dims[1] = kernelChannels;
-    kernel_dims[2] = kernelFilters;
+  } else {
+    kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
+    kernel_dims[0] = kernelChannels;
   }
 
   // The output_backward has dimensions out_depth X out_planes X out_rows X
@@ -208,36 +222,32 @@
   // dimensions:
   //   out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes *
   //   input_rows * input_cols * OTHERS)
-  DSizes<TensorIndex, 3> pre_contract_dims;
+  DSizes<TensorIndex, 2> pre_contract_dims;
   if (isColMajor) {
-    pre_contract_dims[0] = kernelFilters;
-    pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
-    pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+    pre_contract_dims[0] =
+        kernelFilters * kernelPlanes * kernelRows * kernelCols;
+    pre_contract_dims[1] = inputPlanes * inputRows * inputCols;
     for (int i = 4; i < NumDims; ++i) {
-      pre_contract_dims[2] *= out.dimension(i);
+      pre_contract_dims[1] *= out.dimension(i);
     }
   } else {
-    pre_contract_dims[2] = kernelFilters;
-    pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
-    pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
+    pre_contract_dims[1] =
+        kernelFilters * kernelPlanes * kernelRows * kernelCols;
+    pre_contract_dims[0] = inputPlanes * inputRows * inputCols;
     for (int i = 0; i < NumDims - 4; ++i) {
       pre_contract_dims[0] *= out.dimension(i);
     }
   }
 
-  // We will contract along dimensions (0, 2) in kernel and (0, 1) in
-  // output_backward, if this is col-major, and
-  // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this
-  // row-major.
-  array<IndexPair<TensorIndex>, 2> contract_dims;
+  // We will contract along the collapsed dimension that contains the
+  // kernelFilters, kernelPlanes, kernelRows and kernelCols.
+  array<IndexPair<TensorIndex>, 1> contract_dims;
   if (isColMajor) {
     // col-major: kernel.contract(output.patches)
     contract_dims[0] = IndexPair<TensorIndex>(0, 0);
-    contract_dims[1] = IndexPair<TensorIndex>(2, 1);
   } else {
     // row-major: output.patches.contract(kernel)
-    contract_dims[0] = IndexPair<TensorIndex>(1, 0);
-    contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
   }
 
   // Post contraction, the dimensions of the input_backprop is
@@ -261,40 +271,31 @@
     }
   }
 
-  DSizes<TensorIndex, NumDims> strides;
-  for (int i = 0; i < NumDims; i++) {
-    strides[i] = 1;
-  }
-  if (isColMajor) {
-    strides[1] = stridePlanes;
-    strides[2] = strideRows;
-    strides[3] = strideCols;
-  } else {
-    strides[NumDims - 2] = stridePlanes;
-    strides[NumDims - 3] = strideRows;
-    strides[NumDims - 4] = strideCols;
-  }
-
   return choose(
       Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
       kernel.reverse(kernel_reverse)
+          .shuffle(kernel_shuffle)
           .reshape(kernel_dims)
+          .eval()
           .contract(output_backward
                         .extract_volume_patches(
                             kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
-                            stridePlanes, strideRows, strideCols, padding_ztop,
-                            padding_zbottom, padding_top, padding_bottom,
+                            plane_stride, row_stride, col_stride, padding_top_z,
+                            padding_bottom_z, padding_top, padding_bottom,
                             padding_left, padding_right)
                         .reshape(pre_contract_dims),
                     contract_dims)
           .reshape(post_contract_dims),
       output_backward
           .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
-                                  stridePlanes, strideRows, strideCols,
-                                  padding_ztop, padding_zbottom, padding_top,
+                                  plane_stride, row_stride, col_stride,
+                                  padding_top_z, padding_bottom_z, padding_top,
                                   padding_bottom, padding_left, padding_right)
           .reshape(pre_contract_dims)
-          .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
+          .contract(kernel.reverse(kernel_reverse)
+                        .shuffle(kernel_shuffle)
+                        .reshape(kernel_dims)
+                        .eval(),
                     contract_dims)
           .reshape(post_contract_dims));
 }
@@ -322,48 +323,69 @@
  */
 template <typename OutputBackward, typename Input>
 EIGEN_ALWAYS_INLINE static const typename internal::conditional<
-    internal::traits<OutputBackward>::Layout == ColMajor,
-    const TensorShufflingOp<
-        const array<typename internal::traits<OutputBackward>::Index, 5>,
-        const TensorReverseOp<
-            const array<bool, 5>,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<OutputBackward>::Index,
-                             5>,
+    internal::traits<Input>::Layout == ColMajor,
+    const TensorReverseOp<
+        const Eigen::array<typename internal::traits<Input>::Index,
+                           internal::traits<Input>::NumDimensions>,
+        const Eigen::TensorShufflingOp<
+            const Eigen::array<typename internal::traits<Input>::Index,
+                               internal::traits<Input>::NumDimensions>,
+            const Eigen::TensorReshapingOp<
+                const Eigen::DSizes<typename internal::traits<Input>::Index,
+                                    internal::traits<Input>::NumDimensions>,
                 const TensorContractionOp<
                     const array<
-                        IndexPair<typename internal::traits<Input>::Index>, 2>,
+                        IndexPair<typename internal::traits<Input>::Index>, 1>,
+                    const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const Eigen::TensorShufflingOp<
+                            const Eigen::array<
+                                typename internal::traits<Input>::Index,
+                                internal::traits<Input>::NumDimensions>,
+                            const OutputBackward>>>,
                     const TensorReshapingOp<
                         const DSizes<typename internal::traits<Input>::Index,
-                                     3>,
-                        const Input>,
-                    const TensorReshapingOp<
-                        const DSizes<
-                            typename internal::traits<OutputBackward>::Index,
-                            4>,
+                                     2>,
                         const TensorVolumePatchOp<
                             Dynamic, Dynamic, Dynamic,
-                            const OutputBackward> > > > > >,
-    const TensorShufflingOp<
-        const array<typename internal::traits<OutputBackward>::Index, 5>,
-        const TensorReverseOp<
-            const array<bool, 5>,
-            const TensorReshapingOp<
-                const DSizes<typename internal::traits<OutputBackward>::Index,
-                             5>,
+                            const Eigen::TensorForcedEvalOp<
+                                const Eigen::TensorShufflingOp<
+                                    const Eigen::array<
+                                        typename internal::traits<Input>::Index,
+                                        internal::traits<Input>::NumDimensions>,
+                                    const Input>>>>>>>>,
+    const TensorReverseOp<
+        const Eigen::array<typename internal::traits<Input>::Index,
+                           internal::traits<Input>::NumDimensions>,
+        const Eigen::TensorShufflingOp<
+            const Eigen::array<typename internal::traits<Input>::Index,
+                               internal::traits<Input>::NumDimensions>,
+            const Eigen::TensorReshapingOp<
+                const Eigen::DSizes<typename internal::traits<Input>::Index,
+                                    internal::traits<Input>::NumDimensions>,
                 const TensorContractionOp<
                     const array<
-                        IndexPair<typename internal::traits<Input>::Index>, 2>,
-                    const TensorReshapingOp<
-                        const DSizes<
-                            typename internal::traits<OutputBackward>::Index,
-                            4>,
-                        const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
-                                                  const OutputBackward> >,
+                        IndexPair<typename internal::traits<Input>::Index>, 1>,
                     const TensorReshapingOp<
                         const DSizes<typename internal::traits<Input>::Index,
-                                     3>,
-                        const Input> > > > > >::type
+                                     2>,
+                        const TensorVolumePatchOp<
+                            Dynamic, Dynamic, Dynamic,
+                            const Eigen::TensorForcedEvalOp<
+                                const Eigen::TensorShufflingOp<
+                                    const Eigen::array<
+                                        typename internal::traits<Input>::Index,
+                                        internal::traits<Input>::NumDimensions>,
+                                    const Input>>>>,
+                    const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+                        const DSizes<typename internal::traits<Input>::Index,
+                                     2>,
+                        const Eigen::TensorShufflingOp<
+                            const Eigen::array<
+                                typename internal::traits<Input>::Index,
+                                internal::traits<Input>::NumDimensions>,
+                            const OutputBackward>>>>>>>>::type
 CuboidConvolutionBackwardKernel(
     const Input& input, const OutputBackward& output_backward,
     typename internal::traits<Input>::Index kernelPlanes,
@@ -374,11 +396,11 @@
   typedef typename internal::traits<Input>::Index TensorIndex;
   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                    internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex> >
+                   internal::traits<Input>::Layout, TensorIndex>>
       in(input);
   TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
                    internal::traits<OutputBackward>::NumDimensions,
-                   internal::traits<OutputBackward>::Layout, TensorIndex> >
+                   internal::traits<OutputBackward>::Layout, TensorIndex>>
       out(output_backward);
 
   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -392,6 +414,13 @@
                           internal::traits<OutputBackward>::NumDimensions,
                       YOU_MADE_A_PROGRAMMING_MISTAKE);
 
+  // We do not support higher dimensional backward convolutions, or convolutions
+  // without batch dimension.
+  // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+  // dimension in eigen_backward_cuboid_convolutions_test.cc.
+  EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+                      YOU_MADE_A_PROGRAMMING_MISTAKE);
+
   const TensorIndex inputPlanes =
       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
   const TensorIndex inputRows =
@@ -406,213 +435,174 @@
   const TensorIndex outputCols =
       isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
 
+  // Number of filters. This is the same as the output depth.
   const TensorIndex kernelFilters =
       isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+  // Number of channels. This is the same as the input depth.
   const TensorIndex kernelChannels =
       isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
 
-  TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
-  const TensorIndex size_z =
-      Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
-  const TensorIndex size_y =
-      Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
-  const TensorIndex size_x =
-      Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
+  // Number of batches in the input tensor.
+  const TensorIndex batch =
+      isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
 
-  // Infer padding type.
-  if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
-    // SAME padding.
-    const TensorIndex dz = numext::maxi<TensorIndex>(
-        0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
-    const TensorIndex dy = numext::maxi<TensorIndex>(
-        0, (size_y - 1) * strideRows + kernelRows - inputRows);
-    const TensorIndex dx = numext::maxi<TensorIndex>(
-        0, (size_x - 1) * strideCols + kernelCols - inputCols);
+  // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+  // effective kernel planes/rows/cols are always the same as the kernel itself
+  // (see eigen_spatial_convolutions for details).
+  const TensorIndex kernelPlanesEff = kernelPlanes;
+  const TensorIndex kernelRowsEff = kernelRows;
+  const TensorIndex kernelColsEff = kernelCols;
 
-    forward_pad_z = dz / 2;
-    forward_pad_y = dy / 2;
-    forward_pad_x = dx / 2;
-  } else {
-    // VALID padding.
-    forward_pad_z = 0;
-    forward_pad_y = 0;
-    forward_pad_x = 0;
-  }
+  // Compute forward padding from input and output_backward dimensions.
+  const TensorIndex padPlanes = numext::maxi<Index>(
+      0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
+  const TensorIndex padRows = numext::maxi<Index>(
+      0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows);
+  const TensorIndex padCols = numext::maxi<Index>(
+      0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
 
-  const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
-  const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
-  const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
+  const TensorIndex padding_top_z = padPlanes / 2;
+  const TensorIndex padding_top = padRows / 2;
+  const TensorIndex padding_left = padCols / 2;
 
-  const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
-                                      (outputPlanes - 1) * stridePlanes - 1 -
-                                      padding_ztop;
-  const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
-                                     (outputRows - 1) * strideRows - 1 -
-                                     padding_top;
-  const TensorIndex padding_right = inputCols + kernelCols - 1 -
-                                    (outputCols - 1) * strideCols - 1 -
-                                    padding_left;
+  // Compute paddings for output_backward before extracting patches.
+  const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+  const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+  const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+  const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+  const auto padded_out_rows = inputRows + kernelRows - 1;
+  const auto padded_out_cols = inputCols + kernelCols - 1;
+  const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+  const auto top_pad_rows = kernelRows - 1 - padding_top;
+  const auto left_pad_cols = kernelCols - 1 - padding_left;
+  const auto bottom_pad_planes =
+      padded_out_planes - expanded_out_planes - top_pad_planes;
+  const auto bottom_pad_rows =
+      padded_out_rows - expanded_out_rows - top_pad_rows;
+  const auto right_pad_cols =
+      padded_out_cols - expanded_out_cols - left_pad_cols;
 
-  eigen_assert(padding_ztop >= 0);
-  eigen_assert(padding_zbottom >= 0);
-  eigen_assert(padding_top >= 0);
-  eigen_assert(padding_left >= 0);
-  eigen_assert(padding_bottom >= 0);
-  eigen_assert(padding_right >= 0);
-
-  // The output_backward has dimensions out_depth X out_plaens X out_rows X
-  // out_cols X OTHERS
-  // When we extract the image patches from output_backward (with input as the
-  // kernel), it will have dimensions
-  //  (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
-  //  kernel_rows * kernel_cols) X OTHERS
-  DSizes<TensorIndex, 4> pre_contract_dims;
+  // Reorder output_backward dimensions.
+  array<TensorIndex, 5> output_backward_shuffle;
   if (isColMajor) {
-    pre_contract_dims[0] = kernelFilters;
-    pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
-    pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
-    pre_contract_dims[3] = 1;
-    for (int i = 4; i < NumDims; ++i) {
-      pre_contract_dims[3] *= out.dimension(i);
-    }
+    // From: [out_depth, out_planes, out_rows, out_cols, batch]
+    // To:   [batch, out_planes, out_rows, out_cols, out_depth]
+    output_backward_shuffle = {4, 1, 2, 3, 0};
   } else {
-    pre_contract_dims[3] = kernelFilters;
-    pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
-    pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
-    pre_contract_dims[0] = 1;
-    for (int i = 0; i < NumDims - 4; ++i) {
-      pre_contract_dims[0] *= out.dimension(i);
-    }
+    // From: [batch, out_cols, out_rows, out_planes, out_depth]
+    // To:   [out_depth, out_cols, out_rows, out_planes, batch]
+    output_backward_shuffle = {4, 1, 2, 3, 0};
   }
 
-  // The input has dimensions in_depth X (input_planes * input_rows *
-  // input_cols) X OTHERS
-  DSizes<TensorIndex, 3> input_dims;
+  // Reorder input dimensions.
+  array<TensorIndex, 5> input_shuffle;
+  if (isColMajor) {
+    // From: [in_depth, in_planes, in_rows, in_cols, batch]
+    // To:   [in_depth, batch, in_planes, in_rows, in_cols]
+    input_shuffle = {0, 4, 1, 2, 3};
+  } else {
+    // From: [batch, in_cols, in_rows, in_planes, in_depth]
+    // To:   [in_cols, in_rows, in_planes, batch, in_depth]
+    input_shuffle = {1, 2, 3, 0, 4};
+  }
+
+  // Input is playing the role of a "kernel" in this convolution.
+  DSizes<TensorIndex, 2> input_dims;
   if (isColMajor) {
     input_dims[0] = kernelChannels;
-    input_dims[1] = inputRows * inputCols * inputPlanes;
-    input_dims[2] = 1;
-    for (int i = 4; i < NumDims; ++i) {
-      input_dims[2] *= in.dimension(i);
-    }
-    eigen_assert(input_dims[2] == pre_contract_dims[3]);
+    input_dims[1] = batch * inputPlanes * inputRows * inputCols;
   } else {
-    input_dims[2] = kernelChannels;
-    input_dims[1] = inputRows * inputCols * inputPlanes;
-    input_dims[0] = 1;
-    for (int i = 0; i < NumDims - 4; ++i) {
-      input_dims[0] *= in.dimension(i);
-    }
-    eigen_assert(input_dims[0] == pre_contract_dims[0]);
+    input_dims[1] = kernelChannels;
+    input_dims[0] = inputCols * inputRows * inputPlanes * batch;
   }
 
-  // We will contract along dimensions (1, 2) in and (1, 3) in out, if
-  // this is col-major.
-  // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
-  array<IndexPair<TensorIndex>, 2> contract_dims;
+  // Molds the result of the patch extraction into a 2D tensor:
+  // - the first dimension (dims[0]): the patch values to be multiplied with the
+  // kernels
+  // - the second dimension (dims[1]): everything else
+  DSizes<TensorIndex, 2> pre_contract_dims;
   if (isColMajor) {
-    // col-major: in.contract(output.patches)
-    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
-    contract_dims[1] = IndexPair<TensorIndex>(2, 3);
+    pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+    pre_contract_dims[1] =
+        kernelPlanes * kernelRows * kernelCols * kernelFilters;
   } else {
-    // row-major: output.patches.contract(in)
-    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
-    contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+    pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+    pre_contract_dims[0] =
+        kernelFilters * kernelCols * kernelRows * kernelPlanes;
   }
 
-  // After the contraction, the kernel will have dimension
-  //   in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
-  // We will need to shuffle the first two dimensions and reverse the spatial
-  // dimensions.
-  // The end shape is:
-  //   out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+  // We will contract along the collapsed dimension that contains the
+  // batch, inputPlanes, inputRows and inputCols.
+  array<IndexPair<TensorIndex>, 1> contract_dims;
+  contract_dims[0] = IndexPair<TensorIndex>(1, 0);
 
-  // This is the shape of the kernel *before* the shuffling.
-  DSizes<TensorIndex, 5> kernel_dims;
+  // Dimensions after contraction.
+  DSizes<TensorIndex, NumDims> post_contract_dims;
   if (isColMajor) {
-    kernel_dims[0] = kernelChannels;
-    kernel_dims[1] = kernelFilters;
-    kernel_dims[2] = kernelPlanes;
-    kernel_dims[3] = kernelRows;
-    kernel_dims[4] = kernelCols;
+    post_contract_dims[0] = kernelChannels;
+    post_contract_dims[1] = kernelPlanes;
+    post_contract_dims[2] = kernelRows;
+    post_contract_dims[3] = kernelCols;
+    post_contract_dims[4] = kernelFilters;
   } else {
-    kernel_dims[0] = kernelCols;
-    kernel_dims[1] = kernelRows;
-    kernel_dims[2] = kernelPlanes;
-    kernel_dims[3] = kernelFilters;
-    kernel_dims[4] = kernelChannels;
+    post_contract_dims[0] = kernelFilters;
+    post_contract_dims[1] = kernelCols;
+    post_contract_dims[2] = kernelRows;
+    post_contract_dims[3] = kernelPlanes;
+    post_contract_dims[4] = kernelChannels;
   }
 
-  // Flip filters and channels.
+  // Reorder output of contraction to valid filter shape.
   array<TensorIndex, 5> kernel_shuffle;
   if (isColMajor) {
-    kernel_shuffle[0] = 1;
-    kernel_shuffle[1] = 0;
-    kernel_shuffle[2] = 2;
-    kernel_shuffle[3] = 3;
-    kernel_shuffle[4] = 4;
+    // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+    // To:   [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+    kernel_shuffle = {4, 0, 1, 2, 3};
   } else {
-    kernel_shuffle[0] = 0;
-    kernel_shuffle[1] = 1;
-    kernel_shuffle[2] = 2;
-    kernel_shuffle[3] = 4;
-    kernel_shuffle[4] = 3;
+    // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+    // To:   [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+    kernel_shuffle = {1, 2, 3, 4, 0};
   }
 
-  // Reverse the spatial dimensions.
-  array<bool, 5> kernel_reverse;
+  // Reverse kernel backprop dimensions.
+  array<TensorIndex, 5> kernel_reverse;
   if (isColMajor) {
-    kernel_reverse[0] = false;
-    kernel_reverse[1] = false;
-    kernel_reverse[2] = true;
-    kernel_reverse[3] = true;
-    kernel_reverse[4] = true;
+    kernel_reverse = {false, false, true, true, true};
   } else {
-    kernel_reverse[0] = true;
-    kernel_reverse[1] = true;
-    kernel_reverse[2] = true;
-    kernel_reverse[3] = false;
-    kernel_reverse[4] = false;
+    kernel_reverse = {true, true, true, false, false};
   }
 
-  DSizes<TensorIndex, NumDims> strides;
-  for (int i = 0; i < NumDims; i++) {
-    strides[i] = 1;
-  }
-  if (isColMajor) {
-    strides[1] = stridePlanes;
-    strides[2] = strideRows;
-    strides[3] = strideCols;
-  } else {
-    strides[NumDims - 2] = stridePlanes;
-    strides[NumDims - 3] = strideRows;
-    strides[NumDims - 4] = strideCols;
-  }
-  return choose(
-      Cond<internal::traits<Input>::Layout == ColMajor>(),
-      input.reshape(input_dims)
-          .contract(output_backward
+  // Create convolution input (aka source of patches) from output backward
+  // tensor by shuffling dimensions.
+  const auto the_input =
+      output_backward.shuffle(output_backward_shuffle).eval();
+
+  // Create convolution kernel (aka filter) from input by shuffling and
+  // reshaping.
+  const auto the_kernel =
+      input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+  return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+                the_kernel.contract(
+                    the_input
                         .extract_volume_patches(
                             inputPlanes, inputRows, inputCols, 1, 1, 1,
                             stridePlanes, strideRows, strideCols,
-
-                            padding_ztop, padding_zbottom, padding_top,
-                            padding_bottom, padding_left, padding_right)
+                            top_pad_planes, bottom_pad_planes, top_pad_rows,
+                            bottom_pad_rows, left_pad_cols, right_pad_cols)
                         .reshape(pre_contract_dims),
-                    contract_dims)
-          .reshape(kernel_dims)
-          .reverse(kernel_reverse)
-          .shuffle(kernel_shuffle),
-      output_backward
-          .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
-                                  stridePlanes, strideRows, strideCols,
-                                  padding_ztop, padding_zbottom, padding_top,
-                                  padding_bottom, padding_left, padding_right)
-          .reshape(pre_contract_dims)
-          .contract(input.reshape(input_dims), contract_dims)
-          .reshape(kernel_dims)
-          .reverse(kernel_reverse)
-          .shuffle(kernel_shuffle));
+                    contract_dims),
+                the_input
+                    .extract_volume_patches(
+                        inputPlanes, inputRows, inputCols, 1, 1, 1,
+                        stridePlanes, strideRows, strideCols, top_pad_planes,
+                        bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+                        left_pad_cols, right_pad_cols)
+                    .reshape(pre_contract_dims)
+                    .contract(the_kernel, contract_dims))
+      .reshape(post_contract_dims)
+      .shuffle(kernel_shuffle)
+      .reverse(kernel_reverse);
 }
 
 }  // end namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index cb0a76d..960920c 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -189,14 +189,19 @@
   }
 #endif
 
-  // Reorder the dimensions to filters X patch_rows X patch_cols X channels
+  // Reorder the dimensions to:
+  //   filters x patch_rows x patch_cols x channels
   array<TensorIndex, 4> kernel_shuffle;
   if (isColMajor) {
+    //  From: filters x channels x rows x cols
+    //  To:   filters x rows x cols x channels
     kernel_shuffle[0] = 0;
     kernel_shuffle[1] = 2;
     kernel_shuffle[2] = 3;
     kernel_shuffle[3] = 1;
   } else {
+    //  From: cols x rows x channels x filters
+    //  To:   channels x cols x rows x filters
     kernel_shuffle[0] = 2;
     kernel_shuffle[1] = 0;
     kernel_shuffle[2] = 1;
@@ -233,8 +238,8 @@
     }
   }
 
-  // We will contract along the fused dimension that contains the kernelFilters,
-  // the kernelRows and the kernelCols.
+  // We will contract along the collapsed dimension that contains the
+  // kernelFilters, the kernelRows and the kernelCols.
   array<IndexPair<TensorIndex>, 1> contract_dims;
   if (isColMajor) {
     // col-major: kernel.contract(output.patches)
@@ -327,23 +332,16 @@
             const TensorReshapingOp<
                 const DSizes<typename internal::traits<Input>::Index, 2>,
                 const OutputBackward>,
-            const TensorShufflingOp<
-                const array<typename internal::traits<OutputBackward>::Index,
-                            2>,
-                const TensorReshapingOp<
-                    const DSizes<typename internal::traits<Input>::Index, 2>,
-                    const TensorImagePatchOp<Dynamic, Dynamic,
-                                             const Input> > > > >,
+            const TensorReshapingOp<
+                const DSizes<typename internal::traits<Input>::Index, 2>,
+                const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
     TensorReshapingOp<
         const DSizes<typename internal::traits<Input>::Index, 4>,
         const TensorContractionOp<
             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
-            const TensorShufflingOp<
-                const array<typename internal::traits<OutputBackward>::Index,
-                            2>,
-                const TensorReshapingOp<
-                    const DSizes<typename internal::traits<Input>::Index, 2>,
-                    const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+            const TensorReshapingOp<
+                const DSizes<typename internal::traits<Input>::Index, 2>,
+                const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
             const TensorReshapingOp<
                 const DSizes<typename internal::traits<Input>::Index, 2>,
                 const OutputBackward> > > >::type
@@ -451,12 +449,16 @@
     eigen_assert(output_dims[0] == pre_contract_dims[0]);
   }
 
-  array<TensorIndex, 2> shuffle_dims;
-  shuffle_dims[0] = 1;
-  shuffle_dims[1] = 0;
-
+  // We will contract along the collapsed dimension that contains the
+  // outputCols, outputRows and OTHERS.
   array<IndexPair<TensorIndex>, 1> contract_dims;
-  contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+  if (isColMajor) {
+    // col-major: output_backward.contract(input.patches)
+    contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+  } else {
+    // row-major: input.patches.contract(output_backward)
+    contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+  }
 
   // After the contraction, the kernel will have the desired shape
   // out_depth X in_shape X kernel_rows X kernel_cols
@@ -482,8 +484,7 @@
                       kernelRows, kernelCols, row_stride, col_stride,
                       row_in_stride, col_in_stride, 1, 1, padding_top,
                       padding_bottom, padding_left, padding_right, OutScalar(0))
-                  .reshape(pre_contract_dims)
-                  .shuffle(shuffle_dims),
+                  .reshape(pre_contract_dims),
               contract_dims)
           .reshape(kernel_dims),
       input
@@ -492,7 +493,6 @@
                                  padding_top, padding_bottom, padding_left,
                                  padding_right, OutScalar(0))
           .reshape(pre_contract_dims)
-          .shuffle(shuffle_dims)
           .contract(output_backward.reshape(output_dims), contract_dims)
           .reshape(kernel_dims));
 }
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9..673ec14 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@
   const int output_cols = input_cols - patch_cols + 1;
   const int output_planes = input_planes - patch_planes + 1;
 
-  Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+  // TODO(ezhulenev): Support backward kernel convolution without batch
+  // dimension.
+  Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+                         /*num_batches*/ 1);
   Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
                           patch_cols);
-  Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
-                                   output_cols);
+  Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+                                   output_cols, /*num_batches*/ 1);
 
   output_backward = output_backward.constant(11.0f) + output_backward.random();
   input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@
                   if (output_i >= 0 && output_i < output_planes &&
                       output_j >= 0 && output_j < output_rows &&
                       output_k >= 0 && output_k < output_cols) {
-                    expected +=
-                        input(id, i, j, k) *
-                        output_backward(od, output_i, output_j, output_k);
+                    expected += input(id, i, j, k, /*batch*/ 0) *
+                                output_backward(od, output_i, output_j,
+                                                output_k, /*batch*/ 0);
                   }
                 }
               }
@@ -1311,12 +1314,14 @@
   const int output_cols = input_cols - patch_cols + 1;
   const int output_planes = input_planes - patch_planes + 1;
 
-  Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
-                                   input_depth);
+  // TODO(ezhulenev): Support backward kernel convolution without batch
+  // dimension.
+  Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+                                   input_planes, input_depth);
   Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
                                     input_depth, output_depth);
-  Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
-                                             output_planes, output_depth);
+  Tensor<float, 5, RowMajor> output_backward(
+      /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
 
   output_backward = output_backward.constant(11.0f) + output_backward.random();
   input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@
                   if (output_i >= 0 && output_i < output_planes &&
                       output_j >= 0 && output_j < output_rows &&
                       output_k >= 0 && output_k < output_cols) {
-                    expected +=
-                        input(k, j, i, id) *
-                        output_backward(output_k, output_j, output_i, od);
+                    expected += input(/*batch*/ 0, k, j, i, id) *
+                                output_backward(/*batch*/ 0, output_k, output_j,
+                                                output_i, od);
                   }
                 }
               }
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
index 46ad38f..87e41b8 100644
--- a/tensorflow/core/kernels/eigen_benchmark.h
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -76,6 +76,9 @@
 
   void SpatialConvolutionBackwardInput(Dimensions input_dims,
                                        Dimensions filter_dims) {
+    using OutputBackward = TTypes<float, 4>::ConstTensor;
+    using InputBackward = TTypes<float, 4>::Tensor;
+
     Dimensions output_dims(input_dims[0],    // batch
                            input_dims[1],    // input_height
                            input_dims[2],    // input_width
@@ -85,37 +88,37 @@
     Eigen::Index input_rows = input_dims[1];
     Eigen::Index input_cols = input_dims[2];
 
-    Scalar* input_data =
-        static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
     Scalar* filter_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
-    Scalar* output_data =
+    Scalar* output_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+    Scalar* input_backward_data =
+        static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
 
-    device_.memset(input_data, 123, BufferSize(input_dims));
     device_.memset(filter_data, 123, BufferSize(filter_dims));
+    device_.memset(output_backward_data, 123, BufferSize(output_dims));
 
-    Input input(input_data, input_dims);
     Filter filter(filter_data, filter_dims);
-    Output output(output_data, output_dims);
+    OutputBackward output_backward(output_backward_data, output_dims);
+    InputBackward input_backward(input_backward_data, input_dims);
 
     ::tensorflow::testing::StartTiming();
     for (int i = 0; i < iters_; ++i) {
-      output.device(device_) = Eigen::SpatialConvolutionBackwardInput(
-          filter, input, input_rows, input_cols);
-      tensorflow::testing::DoNotOptimize(output);
+      input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+          filter, output_backward, input_rows, input_cols);
+      tensorflow::testing::DoNotOptimize(input_backward);
     }
     ::tensorflow::testing::StopTiming();
 
-    device_.deallocate(input_data);
     device_.deallocate(filter_data);
-    device_.deallocate(output_data);
+    device_.deallocate(output_backward_data);
+    device_.deallocate(input_backward_data);
   }
 
   void SpatialConvolutionBackwardKernel(Dimensions input_dims,
                                         Dimensions filter_dims) {
     using OutputBackward = TTypes<float, 4>::ConstTensor;
-    using FilterGrad = TTypes<float, 4>::Tensor;
+    using FilterBackward = TTypes<float, 4>::Tensor;
 
     Dimensions output_dims(input_dims[0],    // batch
                            input_dims[1],    // input_height
@@ -130,7 +133,7 @@
         static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
     Scalar* output_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
-    Scalar* filter_data =
+    Scalar* filter_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
 
     device_.memset(input_data, 123, BufferSize(input_dims));
@@ -138,19 +141,19 @@
 
     Input input(input_data, input_dims);
     OutputBackward output_backward(output_backward_data, input_dims);
-    FilterGrad filter_grad(filter_data, filter_dims);
+    FilterBackward filter_backward(filter_backward_data, filter_dims);
 
     ::tensorflow::testing::StartTiming();
     for (int i = 0; i < iters_; ++i) {
-      filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+      filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
           input, output_backward, filter_rows, filter_cols);
-      tensorflow::testing::DoNotOptimize(filter_grad);
+      tensorflow::testing::DoNotOptimize(filter_backward);
     }
     ::tensorflow::testing::StopTiming();
 
     device_.deallocate(input_data);
     device_.deallocate(output_backward_data);
-    device_.deallocate(filter_data);
+    device_.deallocate(filter_backward_data);
   }
 
  private:
@@ -215,42 +218,45 @@
                            input_dims[3],    // input_planes
                            filter_dims[4]);  // filter_count
 
+    using OutputBackward = TTypes<float, 5>::ConstTensor;
+    using InputBackward = TTypes<float, 5>::Tensor;
+
     // Assuming that the convolution had SAME padding.
     Eigen::Index input_rows = input_dims[1];
     Eigen::Index input_cols = input_dims[2];
     Eigen::Index input_planes = input_dims[3];
 
-    Scalar* input_data =
-        static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
     Scalar* filter_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
-    Scalar* output_data =
+    Scalar* output_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+    Scalar* input_backward_data =
+        static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
 
-    device_.memset(input_data, 123, BufferSize(input_dims));
     device_.memset(filter_data, 123, BufferSize(filter_dims));
+    device_.memset(output_backward_data, 123, BufferSize(output_dims));
 
-    Input input(input_data, input_dims);
     Filter filter(filter_data, filter_dims);
-    Output output(output_data, output_dims);
+    OutputBackward output_backward(output_backward_data, output_dims);
+    InputBackward input_backward(input_backward_data, input_dims);
 
     ::tensorflow::testing::StartTiming();
     for (int i = 0; i < iters_; ++i) {
-      output.device(device_) = Eigen::CuboidConvolutionBackwardInput(
-          filter, input, input_planes, input_rows, input_cols);
-      tensorflow::testing::DoNotOptimize(output);
+      input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+          filter, output_backward, input_planes, input_rows, input_cols);
+      tensorflow::testing::DoNotOptimize(input_backward);
     }
     ::tensorflow::testing::StopTiming();
 
-    device_.deallocate(input_data);
     device_.deallocate(filter_data);
-    device_.deallocate(output_data);
+    device_.deallocate(output_backward_data);
+    device_.deallocate(input_backward_data);
   }
 
   void CuboidConvolutionBackwardKernel(Dimensions input_dims,
                                        Dimensions filter_dims) {
     using OutputBackward = TTypes<float, 5>::ConstTensor;
-    using FilterGrad = TTypes<float, 5>::Tensor;
+    using FilterBackward = TTypes<float, 5>::Tensor;
 
     Dimensions output_dims(input_dims[0],    // batch
                            input_dims[1],    // input_height
@@ -267,7 +273,7 @@
         static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
     Scalar* output_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
-    Scalar* filter_data =
+    Scalar* filter_backward_data =
         static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
 
     device_.memset(input_data, 123, BufferSize(input_dims));
@@ -275,19 +281,19 @@
 
     Input input(input_data, input_dims);
     OutputBackward output_backward(output_backward_data, output_dims);
-    FilterGrad filter_grad(filter_data, filter_dims);
+    FilterBackward filter_backward(filter_backward_data, filter_dims);
 
     ::tensorflow::testing::StartTiming();
     for (int i = 0; i < iters_; ++i) {
-      filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+      filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
           input, output_backward, filter_planes, filter_rows, filter_cols);
-      tensorflow::testing::DoNotOptimize(filter_grad);
+      tensorflow::testing::DoNotOptimize(filter_backward);
     }
     ::tensorflow::testing::StopTiming();
 
     device_.deallocate(input_data);
     device_.deallocate(output_backward_data);
-    device_.deallocate(filter_data);
+    device_.deallocate(filter_backward_data);
   }
 
  private:
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
index 2a8308e..ec949dd 100644
--- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -48,8 +48,10 @@
 
   benchmark.SpatialConvolution(input_dims, filter_dims);
 
-  auto output_size = input_dims.TotalSize();
-  auto flops = output_size * (input_depth * filter_height * filter_width);
+  auto num_computed_elements =
+      (input_dims.TotalSize() / input_depth) * filter_count;
+  auto flops =
+      num_computed_elements * (input_depth * filter_height * filter_width);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
 
@@ -75,8 +77,9 @@
 
   benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
 
-  auto output_size = input_dims.TotalSize();
-  auto flops = output_size * (input_depth * filter_height * filter_width);
+  auto num_computed_elements = input_dims.TotalSize();
+  auto flops =
+      num_computed_elements * (input_depth * filter_height * filter_width);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
 
@@ -102,8 +105,9 @@
 
   benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
 
-  auto filter_size = filter_dims.TotalSize();
-  auto flops = filter_size * (input_batches * input_height * input_width);
+  auto num_computed_elements = filter_dims.TotalSize();
+  auto flops =
+      num_computed_elements * (input_batches * input_height * input_width);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
 
@@ -123,6 +127,7 @@
 #define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL)          \
   static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
                               FW)(int iters) {                            \
+    ::tensorflow::testing::SetLabel(LABEL);                               \
     SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW);                \
   }                                                                       \
   BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
@@ -130,6 +135,7 @@
 #define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL)      \
   static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
                               FH, FW)(int iters) {                            \
+    ::tensorflow::testing::SetLabel(LABEL);                                   \
     SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW);       \
   }                                                                           \
   BENCHMARK(                                                                  \
@@ -138,6 +144,7 @@
 #define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL)      \
   static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
                               FH, FW)(int iters) {                             \
+    ::tensorflow::testing::SetLabel(LABEL);                                    \
     SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW);       \
   }                                                                            \
   BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC,   \
@@ -263,8 +270,9 @@
 
   benchmark.CuboidConvolution(input_dims, filter_dims);
 
-  auto output_size = input_dims.TotalSize();
-  auto flops = output_size *
+  auto num_computed_elements =
+      (input_dims.TotalSize() / input_depth) * filter_count;
+  auto flops = num_computed_elements *
                (input_depth * filter_height * filter_width * filter_planes);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
@@ -292,8 +300,8 @@
 
   benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
 
-  auto output_size = input_dims.TotalSize();
-  auto flops = output_size *
+  auto num_computed_elements = input_dims.TotalSize();
+  auto flops = num_computed_elements *
                (input_depth * filter_height * filter_width * filter_planes);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
@@ -321,9 +329,9 @@
 
   benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
 
-  auto filter_size = filter_dims.TotalSize();
-  auto flops =
-      filter_size * (input_batches * input_height * input_width * input_planes);
+  auto num_computed_elements = filter_dims.TotalSize();
+  auto flops = num_computed_elements *
+               (input_batches * input_height * input_width * input_planes);
   ::tensorflow::testing::ItemsProcessed(flops * iters);
 }
 
@@ -348,6 +356,7 @@
 #define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL)         \
   static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
                              FP)(int iters) {                                  \
+    ::tensorflow::testing::SetLabel(LABEL);                                    \
     CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP);               \
   }                                                                            \
   BENCHMARK(                                                                   \
@@ -356,6 +365,7 @@
 #define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
   static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
                              FH, FW, FP)(int iters) {                          \
+    ::tensorflow::testing::SetLabel(LABEL);                                    \
     CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP);  \
   }                                                                            \
   BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC,   \
@@ -365,6 +375,7 @@
                                       LABEL)                                   \
   static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C,    \
                              FC, FH, FW, FP)(int iters) {                      \
+    ::tensorflow::testing::SetLabel(LABEL);                                    \
     CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
   }                                                                            \
   BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC,  \
@@ -395,8 +406,17 @@
 BM_CuboidConvolutions(8,              // batch size
                       25, 25, 25, 4,  // input: height, width, panes, depth
                       16, 5, 5, 5,    // filter: count, height, width, panes
-                      "conv3d");
+                      "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
 
-BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
 
-BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f91..c41fbc4 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1362 @@
 
 namespace Eigen {
 
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+//   *0: kernelChannels * kernelDepth * kernelRows * kernelCols;
+//    1: out_depth * out_height * out_width; * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+//    col major layout)
+//
+// With these dimensions:
+//   row - offset within a single patch (in code: patchId)
+//   col - index of the extracted patch (in code: patchIndex)
+//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+          DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+          typename Index, typename nocontract_t, typename contract_t, int Side,
+          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+          int Alignment>
+class TensorContractionInputMapper<
+    Scalar_, Index, Side,
+    TensorEvaluator<const TensorReshapingOp<NewDimension,
+                                            const TensorVolumePatchOp<
+                                                Planes, Rows, Cols, ArgType> >,
+                    Device>,
+    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+    inner_dim_reordered, Alignment> {
+ public:
+  typedef Scalar_ Scalar;
+  typedef TensorContractionInputMapper<
+      Scalar, Index, Side,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      Self;
+  typedef TensorContractionSubMapper<
+      Scalar, Index, Side,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      SubMapper;
+  typedef SubMapper VectorMapper;
+  typedef SubMapper LinearMapper;
+  typedef typename packet_traits<Scalar>::type Packet;
+
+  EIGEN_DEVICE_FUNC
+  TensorContractionInputMapper(
+      const TensorEvaluator<
+          const TensorReshapingOp<
+              NewDimension,
+              const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+          Device>& tensor,
+      const nocontract_t&, const nocontract_t&, const contract_t&,
+      const contract_t&)
+      : m_impl(tensor.impl().impl()) {
+    if (internal::traits<ArgType>::Layout == ColMajor) {
+      m_patch_depth = tensor.impl().dimensions()[0];
+      m_patch_planes = tensor.impl().dimensions()[1];
+      m_patch_rows = tensor.impl().dimensions()[2];
+      m_patch_cols = tensor.impl().dimensions()[3];
+      m_num_patches = tensor.impl().dimensions()[4];
+    } else {
+      const int NumDims = tensor.impl().dimensions().size();
+      m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+      m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+      m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+      m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+      m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+    }
+
+    // Strides for the output tensor.
+    // IMPORTANT: These strides are used to locate an element in a patch at a
+    // depth zero (channel), which is not quite the same as "traditional"
+    // stride.
+    m_rowStride = m_patch_planes;
+    m_colStride = m_patch_rows * m_rowStride;
+    m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+    m_otherStride = m_patchStride * m_num_patches;
+
+    m_outputPlanes = tensor.impl().outputPlanes();
+    m_outputRows = tensor.impl().outputRows();
+    m_outputCols = tensor.impl().outputCols();
+
+    m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+    m_plane_strides = tensor.impl().userPlaneStride();
+    m_row_strides = tensor.impl().userRowStride();
+    m_col_strides = tensor.impl().userColStride();
+
+    m_in_plane_strides = tensor.impl().userInPlaneStride();
+    m_in_row_strides = tensor.impl().userInRowStride();
+    m_in_col_strides = tensor.impl().userInColStride();
+
+    m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+    m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+    if (internal::traits<ArgType>::Layout == ColMajor) {
+      m_inputDepth = tensor.impl().impl().dimensions()[0];
+      m_inputPlanes = tensor.impl().impl().dimensions()[1];
+      m_inputRows = tensor.impl().impl().dimensions()[2];
+      m_inputCols = tensor.impl().impl().dimensions()[3];
+    } else {
+      const int NumDims = tensor.impl().impl().dimensions().size();
+      m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+      m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+    }
+
+    // Strides for navigating through the input tensor.
+    m_planeInputStride = m_inputDepth;
+    m_rowInputStride = m_inputDepth * m_inputPlanes;
+    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+    m_patchInputStride =
+        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+    m_planePaddingTop = tensor.impl().planePaddingTop();
+    m_rowPaddingTop = tensor.impl().rowPaddingTop();
+    m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+    m_fastInputPlaneStride =
+        internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+    m_fastInputRowStride =
+        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+    m_fastInputColStride =
+        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+    m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+    m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+    m_fastOutputPlanesRows =
+        internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+  }
+
+  EIGEN_DEVICE_FUNC
+  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+      : m_impl(base_mapper.m_impl) {
+    m_patch_depth = base_mapper.m_patch_depth;
+    m_patch_planes = base_mapper.m_patch_planes;
+    m_patch_rows = base_mapper.m_patch_rows;
+    m_patch_cols = base_mapper.m_patch_cols;
+    m_num_patches = base_mapper.m_num_patches;
+
+    m_rowStride = base_mapper.m_rowStride;
+    m_colStride = base_mapper.m_colStride;
+    m_patchStride = base_mapper.m_patchStride;
+    m_otherStride = base_mapper.m_otherStride;
+
+    m_planeInputStride = base_mapper.m_planeInputStride;
+    m_rowInputStride = base_mapper.m_rowInputStride;
+    m_colInputStride = base_mapper.m_colInputStride;
+    m_patchInputStride = base_mapper.m_patchInputStride;
+    m_otherInputStride = base_mapper.m_otherInputStride;
+
+    m_inputDepth = base_mapper.m_inputDepth;
+    m_inputPlanes = base_mapper.m_inputPlanes;
+    m_inputRows = base_mapper.m_inputRows;
+    m_inputCols = base_mapper.m_inputCols;
+
+    m_outputPlanes = base_mapper.m_outputPlanes;
+    m_outputRows = base_mapper.m_outputRows;
+    m_outputCols = base_mapper.m_outputCols;
+
+    m_plane_strides = base_mapper.m_plane_strides;
+    m_row_strides = base_mapper.m_row_strides;
+    m_col_strides = base_mapper.m_col_strides;
+
+    m_in_plane_strides = base_mapper.m_in_plane_strides;
+    m_in_row_strides = base_mapper.m_in_row_strides;
+    m_in_col_strides = base_mapper.m_in_col_strides;
+
+    m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+    m_planePaddingTop = base_mapper.m_planePaddingTop;
+    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+    m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+    m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+    m_fastNumPatches = base_mapper.m_fastNumPatches;
+    m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+    m_fastInputColStride = base_mapper.m_fastInputColStride;
+    m_fastRowStride = base_mapper.m_fastRowStride;
+    m_fastColStride = base_mapper.m_fastColStride;
+    m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+    m_fastOutputRows = base_mapper.m_fastOutputRows;
+    m_fastOutputCols = base_mapper.m_fastOutputCols;
+    m_fastDimZero = base_mapper.m_fastDimZero;
+    m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+  }
+
+  // If true, turns off some optimizations for loading packets since the image
+  // patches are "non-standard", e.g. when there are non-trivial strides or
+  // inflations in the input.
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+    return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+           m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+    return SubMapper(*this, i, j);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+    return LinearMapper(*this, i, j);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+    Index planeIndex, rowIndex, colIndex, otherIndex;
+    computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+    return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+  }
+
+  // Load the coefficient at the patchIndex location instead of the usual
+  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+  // gpu code.
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+    Index planeIndex, rowIndex, colIndex, otherIndex;
+    computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+    return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+    Index planeIndex, rowIndex, colIndex, otherIndex;
+    computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+    return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+  }
+
+  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+    Index planeIndex, rowIndex, colIndex, otherIndex;
+    computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+    return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+    return m_impl;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+                                             const Index baseIndex) const {
+    const Index inputIndex = depth + baseIndex;
+    return m_impl.template packet<Unaligned>(inputIndex);
+  }
+
+ private:
+  friend class TensorContractionSubMapper<
+      Scalar, Index, Side,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>;
+
+  // Load coefficient from a patch specified by the "within patch offset"
+  // (patchId) and the precomputed indices of the first element of the patch.
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+                                       Index rowIndex, Index colIndex,
+                                       Index otherIndex) const {
+    // Find the offset of the element wrt the location of the first element.
+    const Index patchOffset = patchId / m_fastDimZero;
+
+    const Index colOffset = patchOffset / m_fastColStride;
+    const Index inputCol = colIndex + colOffset * m_in_col_strides;
+    const Index origInputCol =
+        (m_patch_col_inflate_strides == 1)
+            ? inputCol
+            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+    const Index rowOffset =
+        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+    const Index origInputRow =
+        (m_patch_row_inflate_strides == 1)
+            ? inputRow
+            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+    const Index planeOffset =
+        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+    const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+    const Index origInputPlane =
+        (m_patch_plane_inflate_strides == 1)
+            ? inputPlane
+            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+    if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+        origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+        origInputPlane >= m_inputPlanes ||
+        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+        (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+        (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+      return Scalar(0);
+    }
+
+    const Index depth = patchId - patchOffset * patchDepth();
+    const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+                             origInputRow * m_rowInputStride +
+                             origInputCol * m_colInputStride + otherIndex;
+
+    return m_impl.coeff(inputIndex);
+  }
+
+  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+  // and `in_strides` equal to 1 (template specialization without templates).
+  EIGEN_DEVICE_FUNC
+  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+                                               Index rowIndex, Index colIndex,
+                                               Index otherIndex) const {
+    eigen_assert(!nonStandardPatches());
+
+    // Find the offset of the element wrt the location of the first element.
+    const Index patchOffset = patchId / m_fastDimZero;
+
+    const Index colOffset = patchOffset / m_fastColStride;
+    const Index inputCol = colIndex + colOffset;
+
+    const Index rowOffset =
+        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+    const Index inputRow = rowIndex + rowOffset;
+
+    const Index planeOffset =
+        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+    const Index inputPlane = planeIndex + planeOffset;
+
+    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+        inputRow >= m_inputRows || inputPlane < 0 ||
+        inputPlane >= m_inputPlanes) {
+      return Scalar(0);
+    }
+
+    const Index depth = patchId - patchOffset * patchDepth();
+    const Index inputIndex = depth + inputPlane * m_planeInputStride +
+                             inputRow * m_rowInputStride +
+                             inputCol * m_colInputStride + otherIndex;
+
+    return m_impl.coeff(inputIndex);
+  }
+
+  // Load packet from a patch specified by the "within patch offset"
+  // (patchId) and the precomputed indices of the first element of the patch.
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+                                        Index rowIndex, Index colIndex,
+                                        Index otherIndex) const {
+    const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+    eigen_assert(patchId <
+                 patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+    if (nonStandardPatches()) {
+      return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+                                    otherIndex);
+    }
+    return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+                              otherIndex);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+                                                Index rowIndex, Index colIndex,
+                                                Index otherIndex) const {
+    const Index packetSize = internal::unpacket_traits<Packet>::size;
+    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+    eigen_assert(patchId <
+                 patchDepth() * patchPlanes() * patchRows() * patchCols());
+    eigen_assert(!nonStandardPatches());
+
+    if ((patchDepth() % packetSize) == 0) {
+      return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+                            otherIndex);
+    } else {
+      // Offsets and input calculation here are identical to
+      // loadCoeffStandard(...), but repeated twice.
+
+      const Index patchOffsets[2] = {
+          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+                                   patchOffsets[1] / m_fastColStride};
+      eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+      const Index inputCols[2] = {colIndex + colOffsets[0],
+                                  colIndex + colOffsets[1]};
+      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+        return internal::pset1<Packet>(Scalar(0));
+      }
+
+      if (inputCols[0] == inputCols[1]) {
+        const Index rowOffsets[2] = {
+            (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+            (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+        const Index inputRows[2] = {rowIndex + rowOffsets[0],
+                                    rowIndex + rowOffsets[1]};
+
+        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+          return internal::pset1<Packet>(Scalar(0));
+        }
+
+        if (inputRows[0] == inputRows[1]) {
+          const Index planeOffsets[2] = {
+              patchOffsets[0] - colOffsets[0] * m_colStride -
+                  rowOffsets[0] * m_rowStride,
+              patchOffsets[1] - colOffsets[1] * m_colStride -
+                  rowOffsets[1] * m_rowStride};
+          eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+          const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+                                        planeIndex + planeOffsets[1]};
+
+          if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+            return internal::pset1<Packet>(Scalar(0));
+          }
+
+          if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+            const Index depth = patchId - patchOffsets[0] * patchDepth();
+            const Index inputIndex =
+                depth + inputPlanes[0] * m_planeInputStride +
+                inputRows[0] * m_rowInputStride +
+                inputCols[0] * m_colInputStride + otherIndex;
+            return m_impl.template packet<Unaligned>(inputIndex);
+          }
+        }
+      }
+    }
+
+    return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+                                  otherIndex);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+                                            Index rowIndex, Index colIndex,
+                                            Index otherIndex) const {
+    const Index packetSize = internal::unpacket_traits<Packet>::size;
+    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+    eigen_assert(patchId <
+                 patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+    eigen_assert(!nonStandardPatches());
+    eigen_assert((patchDepth() % packetSize) == 0);
+
+    // Find the offset of the element wrt the location of the first element.
+    const Index patchOffset = patchId / m_fastDimZero;
+    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+    const Index colOffset = patchOffset / m_fastColStride;
+    const Index inputCol = colIndex + colOffset;
+    const Index rowOffset =
+        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+    const Index inputRow = rowIndex + rowOffset;
+    const Index planeOffset =
+        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+    const Index inputPlane = planeIndex + planeOffset;
+
+    if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+        inputCol >= m_inputCols || inputRow >= m_inputRows ||
+        inputPlane >= m_inputPlanes) {
+      return internal::pset1<Packet>(Scalar(0));
+    }
+
+    const Index depth = patchId - patchOffset * patchDepth();
+    const Index inputIndex = depth + inputPlane * m_planeInputStride +
+                             inputRow * m_rowInputStride +
+                             inputCol * m_colInputStride + otherIndex;
+    return m_impl.template packet<Unaligned>(inputIndex);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+  packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+                         Index colIndex, Index otherIndex) const {
+    const int packetSize = internal::unpacket_traits<Packet>::size;
+    EIGEN_ALIGN_MAX
+    typename internal::remove_const<Scalar>::type values[packetSize];
+    for (int i = 0; i < packetSize; ++i) {
+      values[i] =
+          loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+    }
+    Packet rslt = internal::pload<Packet>(values);
+    return rslt;
+  }
+
+  // Precompute the indices (plane, row, col, other) of the first element of
+  // the given patch index, within the output tensor of the TensorVolumePatchOp.
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+      Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+      Index& otherIndex) const {
+    const int NumInputDims = array_size<
+        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+    // Check if patchIndex might contain batch and other dimensions.
+    otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+    // Compute index of the patch within the batch (and other dimensions).
+    const Index patch3DIndex = (NumInputDims == 4)
+                                   ? patchIndex
+                                   : (patchIndex - otherIndex * m_num_patches);
+
+    otherIndex *= m_patchInputStride;
+
+    colIndex = patch3DIndex / m_fastOutputPlanesRows;
+    rowIndex =
+        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+    planeIndex =
+        patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
+
+    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+    planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+  }
+
+  Index m_patch_depth;   // number of channels in the patch
+  Index m_patch_planes;  // number of planes in the patch
+  Index m_patch_rows;    // number of rows in the patch
+  Index m_patch_cols;    // number of columns in the patch
+  Index m_num_patches;   // number of patches to extract
+
+  // Strides for the output tensor.
+  Index m_rowStride;
+  Index m_colStride;
+  Index m_patchStride;
+  Index m_otherStride;
+
+  Index m_planeInputStride;  // Plane stride in the input tensor
+  Index m_rowInputStride;    // Row stride in the input tensor
+  Index m_colInputStride;    // Col stride in the input tensor
+  Index m_patchInputStride;  // Patch stride in the input tensor
+  Index m_otherInputStride;
+
+  Index m_inputDepth;   // Depth of the input tensor
+  Index m_inputPlanes;  // Number of planes in the input tensor
+  Index m_inputRows;    // Number of rows in the input tensor
+  Index m_inputCols;    // Number of cols in the input tensor
+
+  Index m_outputPlanes;      // Number of output planes
+  Index m_outputRows;        // Number of output rows
+  Index m_outputCols;        // Number of output cols
+  Index m_outputPlanesRows;  // Cached outputPlanes * outputRows.
+
+  Index m_plane_strides;  // User specified plane stride
+  Index m_row_strides;    // User specified row stride
+  Index m_col_strides;    // User specified col stride
+
+  // User specified plane/row/col atrous convolution strides.
+  Index m_in_plane_strides;
+  Index m_in_row_strides;
+  Index m_in_col_strides;
+
+  // User specified plane/row/col inflation strides in the image patch.
+  Index m_patch_plane_inflate_strides;
+  Index m_patch_row_inflate_strides;
+  Index m_patch_col_inflate_strides;
+
+  Index m_planePaddingTop;  // Plane padding
+  Index m_rowPaddingTop;    // Row padding
+  Index m_colPaddingLeft;   // Column padding
+
+  // Fast representation of various divisors.
+  internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+  internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+  internal::TensorIntDivisor<Index> m_fastInputRowStride;
+  internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+  internal::TensorIntDivisor<Index> m_fastRowStride;
+  internal::TensorIntDivisor<Index> m_fastColStride;
+
+  internal::TensorIntDivisor<Index> m_fastDimZero;  // aka output depth
+  internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+  internal::TensorIntDivisor<Index> m_fastOutputRows;
+  internal::TensorIntDivisor<Index> m_fastOutputCols;
+  internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+  const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+          typename Index, typename nocontract_t, typename contract_t, int Side,
+          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+          int Alignment>
+class TensorContractionSubMapper<
+    Scalar, Index, Side,
+    TensorEvaluator<const TensorReshapingOp<NewDimension,
+                                            const TensorVolumePatchOp<
+                                                Planes, Rows, Cols, ArgType> >,
+                    Device>,
+    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+    inner_dim_reordered, Alignment> {
+ public:
+  typedef typename packet_traits<Scalar>::type Packet;
+  typedef typename packet_traits<Scalar>::half HalfPacket;
+
+  typedef TensorContractionInputMapper<
+      Scalar, Index, Side,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      ParentMapper;
+  typedef TensorContractionSubMapper<
+      Scalar, Index, Side,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      Self;
+  typedef Self LinearMapper;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+      : m_base_mapper(base_mapper),
+        m_depth_offset(vert_offset),
+        m_col_offset(horiz_offset) {
+    m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+                                     m_colIndex, m_otherIndex);
+  }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+      const Self& base_mapper, Index vert_offset, Index horiz_offset)
+      : m_base_mapper(base_mapper.m_base_mapper),
+        m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+        m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+    m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+                                     m_colIndex, m_otherIndex);
+  }
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+    return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+                                   m_colIndex, m_otherIndex);
+  }
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+                                                          Index j) const {
+    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+    return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+                                    m_rowIndex, m_colIndex, m_otherIndex);
+  }
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+                                                          Index j) const {
+    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+                                                        j + m_col_offset);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+  loadCoeffStandard(Index i) const {
+    return m_base_mapper.loadCoeffStandard(
+        i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+                                        m_rowIndex, m_colIndex, m_otherIndex);
+  }
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+  loadPacketStandard(Index i) const {
+    return m_base_mapper.loadPacketStandard(
+        i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+  }
+  template <typename Packet>
+  EIGEN_DEVICE_FUNC bool aligned(Index) const {
+    return false;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+    return m_base_mapper.nonStandardPatches();
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchDepth() const {
+    return m_base_mapper.m_patch_depth;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+    return m_base_mapper.m_patch_planes;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchRows() const {
+    return m_base_mapper.m_patch_rows;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index patchCols() const {
+    return m_base_mapper.m_patch_cols;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+                                             const Index baseIndex) const {
+    const Index inputIndex = depth + baseIndex;
+    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+    const Index p = m_planeIndex + plane;
+    return p < 0 || p >= m_base_mapper.m_inputPlanes;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+    const Index r = m_rowIndex + row;
+    return r < 0 || r >= m_base_mapper.m_inputRows;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+    const Index c = m_colIndex + col;
+    return c < 0 || c >= m_base_mapper.m_inputCols;
+  }
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+                                      const Index col) const {
+    const Index p = m_planeIndex + plane;
+    const Index r = m_rowIndex + row;
+    const Index c = m_colIndex + col;
+    return p * m_base_mapper.m_planeInputStride +
+           r * m_base_mapper.m_rowInputStride +
+           c * m_base_mapper.m_colInputStride + m_otherIndex;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index planeOffset() const {
+    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+    const Index rowOffset =
+        (patchOffset - colOffset * m_base_mapper.m_colStride) /
+        m_base_mapper.m_fastRowStride;
+    const Index planeOffset = patchOffset -
+                              colOffset * m_base_mapper.m_colStride -
+                              rowOffset * m_base_mapper.m_rowStride;
+    return planeOffset;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index rowOffset() const {
+    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+    const Index rowOffset =
+        (patchOffset - colOffset * m_base_mapper.m_colStride) /
+        m_base_mapper.m_fastRowStride;
+    return rowOffset;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index colOffset() const {
+    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+    return colOffset;
+  }
+
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Index depthOffset() const {
+    const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
+    return patchOffset;
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+  getLinearMapper(Index i, Index j) const {
+    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+  }
+
+ private:
+  const ParentMapper& m_base_mapper;
+  Index m_depth_offset;  // First row in the input matrix
+  Index m_col_offset;    // First col in the input matrix
+
+  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+  // indices for the first element in a patch specified by col_offset
+  // (see computeBaseIndices(...) for details).
+  Index m_planeIndex;
+  Index m_rowIndex;
+  Index m_colIndex;
+  Index m_otherIndex;
+};
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ...
+// A1 B1 C1 D1 E1 F1 G1 H1 ...
+// A2 B2 C2 D2 E2 F2 G2 H2 ...
+// A3 B3 C3 D3 E3 F3 G3 H3 ...
+// A4 B4 C4 D4 E4 F4 G4 H4 ...
+// A5 B5 C5 D5 E5 F5 G5 H5 ...
+// A6 B6 C6 D6 E6 F6 G6 H6 ...
+// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A8 ...
+// ...
+//
+// Packing yields row major output (A0 beside A1 in memory):
+// A0 A1 A2 A3 A4 A5 A6 A7
+// B0 B1 B2 B3 B4 B5 B6 B7
+// C0 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) nr - number of registers along the 'n' dimension.
+//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+//    Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+          typename Index, typename nocontract_t, typename contract_t,
+          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+          int Alignment, int nr>
+struct gemm_pack_rhs<
+    Scalar, Index,
+    TensorContractionSubMapper<
+        Scalar, Index, Rhs,
+        TensorEvaluator<const TensorReshapingOp<
+                            NewDimension, const TensorVolumePatchOp<
+                                              Planes, Rows, Cols, ArgType> >,
+                        Device>,
+        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+        inner_dim_reordered, Alignment>,
+    nr, ColMajor, false, false> {
+  typedef TensorContractionSubMapper<
+      Scalar, Index, Rhs,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      SubMapper;
+  typedef SubMapper DataMapper;
+
+  // Packs a [depth x cols] panel of the virtual RHS matrix into `block`,
+  // four columns at a time, transposing each 4-wide strip into the row-major
+  // micro-panel layout described in the comment above. `stride` and `offset`
+  // are not supported by this packer (asserted zero below).
+  EIGEN_DEVICE_FUNC
+  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+                                    Index depth, Index cols, Index stride = 0,
+                                    Index offset = 0) const {
+    eigen_assert(stride == 0);
+    eigen_assert(offset == 0);
+
+    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+    typedef typename packet_traits<Scalar>::type Packet;
+
+    // Columns are handled 4 at a time; the remainder is copied one by one at
+    // the bottom. peeled_k is the depth prefix coverable by whole packets.
+    const Index packet_cols4 = (cols / 4) * 4;
+    const Index peeled_k = (depth / packet_size) * packet_size;
+    const bool non_standard_patches = rhs.nonStandardPatches();
+
+    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+      // One sub-mapper per output column; each fixes a patch (column) of the
+      // virtual matrix and precomputes that patch's base indices.
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+      Index k = 0;
+      // Fast path: standard patches with a packet-sized-friendly layout. The
+      // 4-high PacketBlock transpose needs packets of at least 4 elements;
+      // smaller packet sizes are handled by the specializations below.
+      if ((packet_size % 4) == 0 && !non_standard_patches) {
+        const Index patch_depth = rhs.patchDepth();
+
+        if ((patch_depth % packet_size) == 0) {
+          const Index patch_cols = rhs.patchCols();
+          const Index patch_rows = rhs.patchRows();
+          const Index patch_planes = rhs.patchPlanes();
+
+          // Walk the patch geometry (col -> row -> plane -> depth) directly,
+          // so padding tests and input base indices are computed once per
+          // (plane, row, col) instead of once per coefficient. startCol/
+          // startRow/startPlane/startDepth resume from the panel's offset
+          // within the patch; max_* clip the walk to peeled_k coefficients.
+          const Index startCol = rhs.colOffset();
+          const Index max_cols = std::min<Index>(
+              Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+                  startCol,
+              patch_cols);
+
+          for (Index c = startCol; c < max_cols; ++c) {
+            eigen_assert(k < peeled_k);
+
+            const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+            const Index max_rows = std::min<Index>(
+                Eigen::divup(
+                    peeled_k - c * patch_rows * patch_planes * patch_depth,
+                    patch_planes * patch_depth) +
+                    startRow,
+                patch_rows);
+
+            // Padding along each dimension is resolved per column of the
+            // output strip and combined below for the innermost loop.
+            const bool pad_col0 = dm0.padCol(c);
+            const bool pad_col1 = dm1.padCol(c);
+            const bool pad_col2 = dm2.padCol(c);
+            const bool pad_col3 = dm3.padCol(c);
+
+            for (Index r = startRow; r < max_rows; ++r) {
+              eigen_assert(k < peeled_k);
+
+              const Index startPlane =
+                  ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+              const Index max_planes = std::min<Index>(
+                  Eigen::divup(
+                      peeled_k -
+                          c * patch_rows * patch_planes * patch_depth -  // col
+                          r * patch_planes * patch_depth,                // row
+                      patch_depth) +
+                      startPlane,
+                  patch_planes);
+
+              const bool pad_row0 = dm0.padRow(r);
+              const bool pad_row1 = dm1.padRow(r);
+              const bool pad_row2 = dm2.padRow(r);
+              const bool pad_row3 = dm3.padRow(r);
+
+              for (Index p = startPlane; p < max_planes; ++p) {
+                eigen_assert(k < peeled_k);
+
+                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+                // Linear input index of (p, r, c) for each patch; the depth
+                // loop below only adds `d` to it (packetNoPadding).
+                const Index idx0 = dm0.baseIndex(p, r, c);
+                const Index idx1 = dm1.baseIndex(p, r, c);
+                const Index idx2 = dm2.baseIndex(p, r, c);
+                const Index idx3 = dm3.baseIndex(p, r, c);
+
+                const Index startDepth =
+                    ((c == startCol) && (r == startRow) && (p == startPlane))
+                        ? rhs.depthOffset()
+                        : 0;
+                const Index max_depth = std::min<Index>(
+                    peeled_k -
+                        c * patch_rows * patch_planes * patch_depth -  // col
+                        r * patch_planes * patch_depth -               // row
+                        p * patch_depth +                              // plane
+                        startDepth,
+                    patch_depth);
+                eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+                for (Index d = startDepth; d < max_depth; d += packet_size) {
+                  eigen_assert(k < peeled_k);
+                  // Gather one packet per column (zero-filled where padded),
+                  // transpose the 4x(packet_size) block, and store it as
+                  // packet_size consecutive rows of 4.
+                  PacketBlock<Packet, 4> kernel;
+                  kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+                                          : rhs.packetNoPadding(d, idx0);
+                  kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+                                          : rhs.packetNoPadding(d, idx1);
+                  kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+                                          : rhs.packetNoPadding(d, idx2);
+                  kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+                                          : rhs.packetNoPadding(d, idx3);
+                  ptranspose(kernel);
+                  pstoreu(block + 0 * packet_size, kernel.packet[0]);
+                  pstoreu(block + 1 * packet_size, kernel.packet[1]);
+                  pstoreu(block + 2 * packet_size, kernel.packet[2]);
+                  pstoreu(block + 3 * packet_size, kernel.packet[3]);
+                  block += 4 * packet_size;
+                  k += packet_size;
+                }
+              }
+            }
+          }
+
+          // Any packets of the peeled range not covered by the geometry walk
+          // above (it is clipped to the patch bounds) — load via the
+          // per-coefficient index computation instead.
+          for (; k < peeled_k; k += packet_size) {
+            PacketBlock<Packet, 4> kernel;
+            kernel.packet[0] = dm0.loadPacketFast(k);
+            kernel.packet[1] = dm1.loadPacketFast(k);
+            kernel.packet[2] = dm2.loadPacketFast(k);
+            kernel.packet[3] = dm3.loadPacketFast(k);
+            ptranspose(kernel);
+            pstoreu(block + 0 * packet_size, kernel.packet[0]);
+            pstoreu(block + 1 * packet_size, kernel.packet[1]);
+            pstoreu(block + 2 * packet_size, kernel.packet[2]);
+            pstoreu(block + 3 * packet_size, kernel.packet[3]);
+            block += 4 * packet_size;
+          }
+        } else {
+          // patch_depth is not a multiple of packet_size: fall back to packet
+          // loads that recompute padding/indices per coefficient block.
+          for (; k < peeled_k; k += packet_size) {
+            PacketBlock<Packet, 4> kernel;
+            kernel.packet[0] = dm0.loadPacketStandard(k);
+            kernel.packet[1] = dm1.loadPacketStandard(k);
+            kernel.packet[2] = dm2.loadPacketStandard(k);
+            kernel.packet[3] = dm3.loadPacketStandard(k);
+            ptranspose(kernel);
+            pstoreu(block + 0 * packet_size, kernel.packet[0]);
+            pstoreu(block + 1 * packet_size, kernel.packet[1]);
+            pstoreu(block + 2 * packet_size, kernel.packet[2]);
+            pstoreu(block + 3 * packet_size, kernel.packet[3]);
+            block += 4 * packet_size;
+          }
+        }
+      }
+      // Scalar copy of the remaining depth (and of the whole depth when the
+      // vectorized paths above did not apply).
+      if (!rhs.nonStandardPatches()) {
+        for (; k < depth; k++) {
+          block[0] = dm0.loadCoeffStandard(k);
+          block[1] = dm1.loadCoeffStandard(k);
+          block[2] = dm2.loadCoeffStandard(k);
+          block[3] = dm3.loadCoeffStandard(k);
+          block += 4;
+        }
+      } else {
+        for (; k < depth; k++) {
+          block[0] = dm0(k);
+          block[1] = dm1(k);
+          block[2] = dm2(k);
+          block[3] = dm3(k);
+          block += 4;
+        }
+      }
+    }
+
+    // copy the remaining columns one at a time (nr==1)
+    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+      for (Index k = 0; k < depth; k++) {
+        *block = dm0(k);
+        block += 1;
+      }
+    }
+  }
+};
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+          typename Index, typename nocontract_t, typename contract_t,
+          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+          int nr>
+struct gemm_pack_rhs<
+    Scalar, Index,
+    TensorContractionSubMapper<
+        Scalar, Index, Rhs,
+        TensorEvaluator<const TensorReshapingOp<
+                            NewDimension, const TensorVolumePatchOp<
+                                              Planes, Rows, Cols, ArgType> >,
+                        Device>,
+        nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+        inner_dim_reordered, Alignment>,
+    nr, ColMajor, false, false> {
+  typedef TensorContractionSubMapper<
+      Scalar, Index, Rhs,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+      inner_dim_reordered, Alignment>
+      SubMapper;
+  typedef SubMapper DataMapper;
+
+  // Same packing scheme as the generic specialization above, but with
+  // 2-element packets: each 4x2 strip is transposed as two 2x2 PacketBlocks
+  // (a PacketBlock<Packet, 4> cannot be transposed when nr > packet_size).
+  EIGEN_DEVICE_FUNC
+  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+                                    Index depth, Index cols, Index stride = 0,
+                                    Index offset = 0) const {
+    eigen_assert(stride == 0);
+    eigen_assert(offset == 0);
+
+    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+    typedef typename packet_traits<Scalar>::type Packet;
+
+    const int packet_size = 2;
+
+    // Columns are handled 4 at a time; peeled_k is the depth prefix that can
+    // be processed with whole (2-element) packets.
+    const Index packet_cols4 = (cols / 4) * 4;
+    const Index peeled_k = (depth / packet_size) * packet_size;
+    const bool non_standard_patches = rhs.nonStandardPatches();
+
+    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+      // One sub-mapper per output column of the 4-wide strip.
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+      Index k = 0;
+      if (!non_standard_patches) {
+        const Index patch_depth = rhs.patchDepth();
+
+        if ((patch_depth % packet_size) == 0) {
+          const Index patch_cols = rhs.patchCols();
+          const Index patch_rows = rhs.patchRows();
+          const Index patch_planes = rhs.patchPlanes();
+
+          // Walk the patch geometry (col -> row -> plane -> depth) so padding
+          // tests and input base indices are hoisted out of the depth loop.
+          // start*/max_* resume from and clip to the panel's span of
+          // peeled_k coefficients within the patch.
+          const Index startCol = rhs.colOffset();
+          const Index max_cols = std::min<Index>(
+              Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
+                  startCol,
+              patch_cols);
+
+          for (Index c = startCol; c < max_cols; ++c) {
+            eigen_assert(k < peeled_k);
+
+            const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
+            const Index max_rows = std::min<Index>(
+                Eigen::divup(
+                    peeled_k - c * patch_rows * patch_planes * patch_depth,
+                    patch_planes * patch_depth) +
+                    startRow,
+                patch_rows);
+
+            const bool pad_col0 = dm0.padCol(c);
+            const bool pad_col1 = dm1.padCol(c);
+            const bool pad_col2 = dm2.padCol(c);
+            const bool pad_col3 = dm3.padCol(c);
+
+            for (Index r = startRow; r < max_rows; ++r) {
+              eigen_assert(k < peeled_k);
+
+              const Index startPlane =
+                  ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
+              const Index max_planes = std::min<Index>(
+                  Eigen::divup(
+                      peeled_k -
+                          c * patch_rows * patch_planes * patch_depth -  // col
+                          r * patch_planes * patch_depth,                // row
+                      patch_depth) +
+                      startPlane,
+                  patch_planes);
+
+              const bool pad_row0 = dm0.padRow(r);
+              const bool pad_row1 = dm1.padRow(r);
+              const bool pad_row2 = dm2.padRow(r);
+              const bool pad_row3 = dm3.padRow(r);
+
+              for (Index p = startPlane; p < max_planes; ++p) {
+                eigen_assert(k < peeled_k);
+
+                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+                // Linear input index of (p, r, c) for each patch; the depth
+                // loop only adds `d` to it (packetNoPadding).
+                const Index idx0 = dm0.baseIndex(p, r, c);
+                const Index idx1 = dm1.baseIndex(p, r, c);
+                const Index idx2 = dm2.baseIndex(p, r, c);
+                const Index idx3 = dm3.baseIndex(p, r, c);
+
+                const Index startDepth =
+                    ((c == startCol) && (r == startRow) && (p == startPlane))
+                        ? rhs.depthOffset()
+                        : 0;
+                const Index max_depth = std::min<Index>(
+                    peeled_k -
+                        c * patch_rows * patch_planes * patch_depth -  // col
+                        r * patch_planes * patch_depth -               // row
+                        p * patch_depth +                              // plane
+                        startDepth,
+                    patch_depth);
+                eigen_assert((max_depth - startDepth) % packet_size == 0);
+
+                for (Index d = startDepth; d < max_depth; d += packet_size) {
+                  eigen_assert(k < peeled_k);
+                  // kernel0 holds columns {dm0, dm1}, kernel1 holds
+                  // {dm2, dm3}. After the two 2x2 transposes the stores are
+                  // interleaved so each output row is {dm0, dm1, dm2, dm3}.
+                  PacketBlock<Packet, 2> kernel0;
+                  PacketBlock<Packet, 2> kernel1;
+                  kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+                                           : rhs.packetNoPadding(d, idx0);
+                  kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+                                           : rhs.packetNoPadding(d, idx1);
+                  kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+                                           : rhs.packetNoPadding(d, idx2);
+                  kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+                                           : rhs.packetNoPadding(d, idx3);
+                  ptranspose(kernel0);
+                  ptranspose(kernel1);
+                  pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+                  pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+                  pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+                  pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+                  block += 4 * packet_size;
+                  k += packet_size;
+                }
+              }
+            }
+          }
+
+          // Any packets of the peeled range not covered by the geometry walk
+          // above (it is clipped to the patch bounds).
+          for (; k < peeled_k; k += packet_size) {
+            PacketBlock<Packet, 2> kernel0;
+            PacketBlock<Packet, 2> kernel1;
+            kernel0.packet[0] = dm0.loadPacketFast(k);
+            kernel0.packet[1] = dm1.loadPacketFast(k);
+            kernel1.packet[0] = dm2.loadPacketFast(k);
+            kernel1.packet[1] = dm3.loadPacketFast(k);
+            ptranspose(kernel0);
+            ptranspose(kernel1);
+            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+            block += 4 * packet_size;
+          }
+        } else {
+          // patch_depth is not a multiple of packet_size: fall back to packet
+          // loads that recompute padding/indices per coefficient block.
+          for (; k < peeled_k; k += packet_size) {
+            PacketBlock<Packet, 2> kernel0;
+            PacketBlock<Packet, 2> kernel1;
+            kernel0.packet[0] = dm0.loadPacketStandard(k);
+            kernel0.packet[1] = dm1.loadPacketStandard(k);
+            kernel1.packet[0] = dm2.loadPacketStandard(k);
+            kernel1.packet[1] = dm3.loadPacketStandard(k);
+            ptranspose(kernel0);
+            ptranspose(kernel1);
+            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+            block += 4 * packet_size;
+          }
+        }
+      }
+      // Scalar copy of the remaining depth (and of the whole depth when the
+      // patches are non-standard).
+      if (!rhs.nonStandardPatches()) {
+        for (; k < depth; k++) {
+          block[0] = dm0.loadCoeffStandard(k);
+          block[1] = dm1.loadCoeffStandard(k);
+          block[2] = dm2.loadCoeffStandard(k);
+          block[3] = dm3.loadCoeffStandard(k);
+          block += 4;
+        }
+      } else {
+        for (; k < depth; k++) {
+          block[0] = dm0(k);
+          block[1] = dm1(k);
+          block[2] = dm2(k);
+          block[3] = dm3(k);
+          block += 4;
+        }
+      }
+    }
+
+    // copy the remaining columns one at a time (nr==1)
+    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+      for (Index k = 0; k < depth; k++) {
+        *block = dm0(k);
+        block += 1;
+      }
+    }
+  }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+          typename Index, typename nocontract_t, typename contract_t,
+          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+          int nr>
+struct gemm_pack_rhs<
+    Scalar, Index,
+    TensorContractionSubMapper<
+        Scalar, Index, Rhs,
+        TensorEvaluator<const TensorReshapingOp<
+                            NewDimension, const TensorVolumePatchOp<
+                                              Planes, Rows, Cols, ArgType> >,
+                        Device>,
+        nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+        inner_dim_reordered, Alignment>,
+    nr, ColMajor, false, false> {
+  typedef TensorContractionSubMapper<
+      Scalar, Index, Rhs,
+      TensorEvaluator<const TensorReshapingOp<
+                          NewDimension, const TensorVolumePatchOp<
+                                            Planes, Rows, Cols, ArgType> >,
+                      Device>,
+      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+      Alignment>
+      SubMapper;
+  typedef SubMapper DataMapper;
+
+  // Scalar-only packing (no packet loads, no depth peeling): copies the
+  // [depth x cols] panel into `block` four columns at a time, row-major.
+  // `stride` and `offset` are not supported (asserted zero below).
+  EIGEN_DEVICE_FUNC
+  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+                                    Index depth, Index cols, Index stride = 0,
+                                    Index offset = 0) const {
+    eigen_assert(stride == 0);
+    eigen_assert(offset == 0);
+
+    EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+    // Columns handled 4 at a time; the remainder is copied one by one below.
+    const Index packet_cols4 = (cols / 4) * 4;
+
+    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+      // One sub-mapper per output column of the 4-wide strip.
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+      // Standard patches use the cheaper coefficient loader; otherwise fall
+      // back to the generic operator() of the sub-mapper.
+      if (!rhs.nonStandardPatches()) {
+        for (Index k = 0; k < depth; k++) {
+          block[0] = dm0.loadCoeffStandard(k);
+          block[1] = dm1.loadCoeffStandard(k);
+          block[2] = dm2.loadCoeffStandard(k);
+          block[3] = dm3.loadCoeffStandard(k);
+          block += 4;
+        }
+      } else {
+        for (Index k = 0; k < depth; k++) {
+          block[0] = dm0(k);
+          block[1] = dm1(k);
+          block[2] = dm2(k);
+          block[3] = dm3(k);
+          block += 4;
+        }
+      }
+    }
+
+    // copy the remaining columns one at a time (nr==1)
+    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+      for (Index k = 0; k < depth; k++) {
+        *block = dm0(k);
+        block += 1;
+      }
+    }
+  }
+};
+
+}  // namespace internal
+
 /** CuboidConvolution
  * \ingroup CXX11_NeuralNetworks_Module
  *
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index cd2873b..7710cf9 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -21,6 +21,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/platform/prefetch.h"
 #include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 66ae7f0..277ee2b 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -123,10 +123,10 @@
 // is considerably more efficient.
 #pragma omp parallel for
     for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
-      const Eigen::array<Eigen::DenseIndex, 1> loc = i;
+      const Eigen::array<Eigen::DenseIndex, 1> loc{i};
       gather_nd_generator(loc);
     }
-#else
+#else  // INTEL_MKL
     Tscratch.device(d) = Tscratch.reshape(reshape_dims)
                              .broadcast(broadcast_dims)
                              .generate(gather_nd_generator)
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index c7dbefa..86146f7 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -123,8 +123,7 @@
   string GetActionSummary(StringPiece action, const Parameters& params,
                           const Config& config) {
     return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
-                           std::string(action).c_str(),
-                           params.ToString().c_str(),
+                           string(action).c_str(), params.ToString().c_str(),
                            config.ToString().c_str());
   }
 
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index bca1cff..2088c13 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@
   return Status::OK();
 }
 
-#define REGISTER_LIST_COPY(DIRECTION)                   \
-  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
-      TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION)                                         \
+  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+                                                       TensorListDeviceCopy)
 
 REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
 REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@
   return Status::OK();
 }
 
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
-                                      TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
 
 bool TensorList::Decode(const VariantTensorData& data) {
   tensors = data.tensors();
@@ -625,12 +624,11 @@
 #undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
 
 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
-                                          TensorList, TensorList::kTypeName,
+                                          TensorList,
                                           TensorListBinaryAdd<CPUDevice>);
 
 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_CPU, TensorList,
-                                         TensorList::kTypeName,
                                          TensorListZerosLike<CPUDevice>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index c591226..a00bf70 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -94,11 +94,10 @@
 #undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
 
 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
-                                          TensorList, TensorList::kTypeName,
+                                          TensorList,
                                           TensorListBinaryAdd<GPUDevice>);
 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_GPU, TensorList,
-                                         TensorList::kTypeName,
                                          TensorListZerosLike<GPUDevice>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 066a1d6..72581c9 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -374,7 +374,12 @@
   y->tensors.reserve(x.tensors.size());
   for (const Tensor& t : x.tensors) {
     Tensor out_tensor;
-    TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
+    AllocatorAttributes attr;
+    if (t.dtype() == DT_VARIANT) {
+      attr.set_on_host(true);
+    }
+    TF_RETURN_IF_ERROR(
+        c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
     switch (out_tensor.dtype()) {
 #define DTYPE_CASE(dtype)                                        \
   case DataTypeToEnum<dtype>::value:                             \
@@ -385,6 +390,20 @@
       TF_CALL_POD_TYPES(DTYPE_CASE)
 
 #undef DTYPE_CASE
+
+      case DataTypeToEnum<Variant>::value: {
+        const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
+        if (inner_x == nullptr) {
+          return errors::InvalidArgument("Input handle is not a list. Saw: '",
+                                         t.scalar<Variant>()().DebugString(),
+                                         "'");
+        }
+        TensorList inner_y;
+        TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
+        out_tensor.scalar<Variant>()() = std::move(inner_y);
+        break;
+      }
+
       default:
         return errors::InvalidArgument(
             "Trying to compute zeros_like for unsupported dtype ",
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 2e8d9c6..a495758 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@
   MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
 
   size_t size() const override {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     return table_.size();
   }
 
@@ -60,7 +60,7 @@
     const auto key_values = key.flat<K>();
     auto value_values = value->flat<V>();
 
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     for (int64 i = 0; i < key_values.size(); ++i) {
       value_values(i) = gtl::FindWithDefault(
           table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@
   }
 
   Status ExportValues(OpKernelContext* ctx) override {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     int64 size = table_.size();
 
     Tensor* keys;
@@ -125,7 +125,7 @@
 
   int64 MemoryUsed() const override {
     int64 ret = 0;
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     for (unsigned i = 0; i < table_.bucket_count(); ++i) {
       size_t bucket_size = table_.bucket_size(i);
       if (bucket_size == 0) {
@@ -138,7 +138,6 @@
   }
 
  private:
-  // TODO(andreasst): consider using a read/write lock or a concurrent map
   mutable mutex mu_;
   std::unordered_map<K, V> table_ GUARDED_BY(mu_);
 };
@@ -158,7 +157,7 @@
   }
 
   size_t size() const override {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     return table_.size();
   }
 
@@ -169,7 +168,7 @@
     auto value_values = value->flat_inner_dims<V, 2>();
     int64 value_dim = value_shape_.dim_size(0);
 
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     for (int64 i = 0; i < key_values.size(); ++i) {
       ValueArray* value_vec =
           gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@
   }
 
   Status ExportValues(OpKernelContext* ctx) override {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     int64 size = table_.size();
     int64 value_dim = value_shape_.dim_size(0);
 
@@ -254,7 +253,7 @@
 
   int64 MemoryUsed() const override {
     int64 ret = 0;
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     for (unsigned i = 0; i < table_.bucket_count(); ++i) {
       size_t bucket_size = table_.bucket_size(i);
       if (bucket_size == 0) {
@@ -268,7 +267,6 @@
 
  private:
   TensorShape value_shape_;
-  // TODO(andreasst): consider using a read/write lock or a concurrent map
   mutable mutex mu_;
   typedef gtl::InlinedVector<V, 4> ValueArray;
   std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,7 +333,7 @@
   }
 
   size_t size() const override LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     return num_entries_;
   }
 
@@ -355,7 +353,7 @@
     auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
     const auto default_flat = default_value.flat<V>();
 
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     const auto key_buckets_matrix =
         key_buckets_.AccessTensor(ctx)->template matrix<K>();
     const auto value_buckets_matrix =
@@ -451,7 +449,7 @@
   }
 
   Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
     Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
     TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -493,7 +491,7 @@
   TensorShape value_shape() const override { return value_shape_; }
 
   int64 MemoryUsed() const override {
-    mutex_lock l(mu_);
+    tf_shared_lock l(mu_);
     return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
            value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
   }
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index bdc3b57..dd89597 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -410,8 +410,9 @@
         copy_or_move_tensors(&it->second, *key, *indices, tuple));
 
     // Remove entry if all the values have been consumed
-    if (!std::any_of(it->second.begin(), it->second.end(),
-                     std::mem_fn(&OptionalTensor::has_value))) {
+    if (!std::any_of(
+            it->second.begin(), it->second.end(),
+            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
       map_.erase(it);
     }
 
@@ -444,8 +445,9 @@
     *key = it->first;
 
     // Remove entry if all the values have been consumed
-    if (!std::any_of(it->second.begin(), it->second.end(),
-                     std::mem_fn(&OptionalTensor::has_value))) {
+    if (!std::any_of(
+            it->second.begin(), it->second.end(),
+            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
       map_.erase(it);
     }
 
diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
index 10e468c..693ed8a 100644
--- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
+++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
@@ -114,9 +114,7 @@
     // Exercises "delete_old_dirs".
     for (int i = 0; i < 2; ++i) {
       int directory_found =
-          Env::Default()
-              ->IsDirectory(std::string(io::Dirname(prefixes[i])))
-              .code();
+          Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
       if (delete_old_dirs) {
         EXPECT_EQ(error::NOT_FOUND, directory_found);
       } else {
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 9b10c3f..184e0cb 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -1083,7 +1083,7 @@
 #endif
 
 // Register 2D operations
-#define REGISTER_MKL_CPU(T)                                         \
+#define REGISTER_MKL_CPU_2D(T)                                      \
   REGISTER_KERNEL_BUILDER(Name("_MklConv2D")                        \
                               .Device(DEVICE_CPU)                   \
                               .TypeConstraint<T>("T")               \
@@ -1100,16 +1100,16 @@
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklDummyOp<CPUDevice, T>);
 
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_2D);
 
 // Register 3D operations
-#define REGISTER_MKL_CPU(T)                                         \
+#define REGISTER_MKL_CPU_3D(T)                                      \
   REGISTER_KERNEL_BUILDER(Name("_MklConv3D")                        \
                               .Device(DEVICE_CPU)                   \
                               .TypeConstraint<T>("T")               \
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklConvOp<CPUDevice, T, false>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_3D);
 
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index ec6d241..5398e61 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -34,11 +34,11 @@
 
 template <typename T>
 void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
-  if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
-      fwdParams.alg_kind != pooling_avg_include_padding &&
-      fwdParams.alg_kind != pooling_avg_exclude_padding) {
-    assert("Pooling algorithm kind is not supported\n");
-  }
+  DCHECK(fwdParams.alg_kind == pooling_max ||
+         fwdParams.alg_kind == pooling_avg ||
+         fwdParams.alg_kind == pooling_avg_include_padding ||
+         fwdParams.alg_kind == pooling_avg_exclude_padding)
+      << "Pooling algorithm kind is not supported";
 
   context_.alg_kind = fwdParams.alg_kind;
   // create memory desc
@@ -102,7 +102,7 @@
       static_cast<void*>(const_cast<T*>(src_data)));
   context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
   if (context_.alg_kind == pooling_max) {  // max pooling must have ws
-    assert(ws_data != nullptr);
+    DCHECK(ws_data != nullptr);
     context_.ws_mem->set_data_handle(ws_data);
   }
   context_.fwd_stream->submit(context_.fwd_primitives);
@@ -111,7 +111,7 @@
   context_.src_mem->set_data_handle(DummyData);
   context_.dst_mem->set_data_handle(DummyData);
   if (context_.alg_kind == pooling_max) {  // max pooling must have ws
-    assert(ws_data != nullptr);
+    DCHECK(ws_data != nullptr);
     context_.ws_mem->set_data_handle(DummyData);
   }
 }
@@ -120,11 +120,11 @@
 
 template <typename T>
 void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
-  if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
-      bwdParams.alg_kind != pooling_avg_include_padding &&
-      bwdParams.alg_kind != pooling_avg_exclude_padding) {
-    assert("Pooling algorithm kind is not supported\n");
-  }
+  DCHECK(bwdParams.alg_kind == pooling_max ||
+         bwdParams.alg_kind == pooling_avg ||
+         bwdParams.alg_kind == pooling_avg_include_padding ||
+         bwdParams.alg_kind == pooling_avg_exclude_padding)
+      << "Pooling algorithm kind is not supported";
   context_.alg_kind = bwdParams.alg_kind;
 
   // check whether it is 2d or 3d
@@ -190,7 +190,7 @@
       static_cast<void*>(const_cast<T*>(diff_dst_data)));
   context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
   if (context_.alg_kind == pooling_max) {
-    assert(ws_data != nullptr);
+    DCHECK(ws_data != nullptr);
     context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
   }
 
@@ -199,7 +199,7 @@
   context_.diff_dst_mem->set_data_handle(DummyData);
   context_.diff_src_mem->set_data_handle(DummyData);
   if (context_.alg_kind == pooling_max) {
-    assert(ws_data != nullptr);
+    DCHECK(ws_data != nullptr);
     context_.ws_mem->set_data_handle(DummyData);
   }
 }
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index f4cfc48..8438535 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -40,7 +40,6 @@
 #include "mkl_dnn.h"
 #include "mkl_dnn_types.h"
 #endif
-#include "tensorflow/core/platform/default/logging.h"
 #include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 04d8a1b..cfab529 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -88,6 +88,7 @@
           break;
         default:
           OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5 and >=1"));
+          return;
       }
       // Create softmax memory for src, dst: both are defined in mkl_util.h,
       // they are wrapper
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d9257e..81ce6d6 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@
 }
 
 // Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
-    typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
-    float iou_threshold) {
-  const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
-  const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
-  const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
-  const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
-  const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
-  const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
-  const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
-  const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
-  const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
-  const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
-  if (area_i <= 0 || area_j <= 0) return 0.0;
-  const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
-  const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
-  const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
-  const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
-  const float intersection_area =
-      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
-      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
-  const float iou = intersection_area / (area_i + area_j - intersection_area);
+template <typename T>
+static inline bool IOUGreaterThanThreshold(
+    typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
+  const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
+  const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
+  const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
+  const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
+  const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
+  const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
+  const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
+  const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
+  const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+  const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+  if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return 0;
+  const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
+  const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
+  const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
+  const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
+  const T intersection_area =
+      std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
+      std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
+  const T iou = intersection_area / (area_i + area_j - intersection_area);
   return iou > iou_threshold;
 }
 
@@ -106,11 +106,13 @@
   return overlaps(i, j) > overlap_threshold;
 }
 
+template <typename T>
 static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
     const Tensor& boxes, float threshold) {
-  typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
-  return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
-                   std::placeholders::_2, threshold);
+  typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
+  return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
+                   std::placeholders::_1, std::placeholders::_2,
+                   static_cast<T>(threshold));
 }
 
 static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
@@ -121,6 +123,7 @@
                    std::placeholders::_1, std::placeholders::_2, threshold);
 }
 
+template <typename T>
 void DoNonMaxSuppressionOp(
     OpKernelContext* context, const Tensor& scores, int num_boxes,
     const Tensor& max_output_size, const float score_threshold,
@@ -128,13 +131,13 @@
     bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
   const int output_size = max_output_size.scalar<int>()();
 
-  std::vector<float> scores_data(num_boxes);
-  std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+  std::vector<T> scores_data(num_boxes);
+  std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
 
   // Data structure for selection candidate in NMS.
   struct Candidate {
     int box_index;
-    float score;
+    T score;
   };
 
   auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
@@ -143,13 +146,13 @@
   std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
       candidate_priority_queue(cmp);
   for (int i = 0; i < scores_data.size(); ++i) {
-    if (scores_data[i] > score_threshold) {
+    if (static_cast<float>(scores_data[i]) > score_threshold) {
       candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
     }
   }
 
   std::vector<int> selected;
-  std::vector<float> selected_scores;
+  std::vector<T> selected_scores;
   Candidate next_candidate;
 
   while (selected.size() < output_size && !candidate_priority_queue.empty()) {
@@ -176,7 +179,7 @@
   int num_valid_outputs = selected.size();
   if (pad_to_max_output_size) {
     selected.resize(output_size, 0);
-    selected_scores.resize(output_size, 0);
+    selected_scores.resize(output_size, static_cast<T>(0));
   }
   if (ptr_num_valid_outputs) {
     *ptr_num_valid_outputs = num_valid_outputs;
@@ -221,18 +224,19 @@
     if (!context->status().ok()) {
       return;
     }
-    auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+    auto suppress_check_fn =
+        CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
 
     const float score_threshold_val = std::numeric_limits<float>::lowest();
-    DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
-                          score_threshold_val, suppress_check_fn);
+    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+                                 score_threshold_val, suppress_check_fn);
   }
 
  private:
   float iou_threshold_;
 };
 
-template <typename Device>
+template <typename Device, typename T>
 class NonMaxSuppressionV2Op : public OpKernel {
  public:
   explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
@@ -264,11 +268,12 @@
     if (!context->status().ok()) {
       return;
     }
-    auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+    auto suppress_check_fn =
+        CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
 
     const float score_threshold_val = std::numeric_limits<float>::lowest();
-    DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
-                          score_threshold_val, suppress_check_fn);
+    DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
+                             score_threshold_val, suppress_check_fn);
   }
 };
 
@@ -325,7 +330,7 @@
   float score_threshold_val_;
 };
 
-template <typename Device>
+template <typename Device, typename T>
 class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
  public:
   explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
@@ -334,14 +339,14 @@
  protected:
   void DoComputeAndPostProcess(OpKernelContext* context) override {
     auto suppress_check_fn =
-        CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+        CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
 
-    DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
-                          score_threshold_val_, suppress_check_fn);
+    DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+                             score_threshold_val_, suppress_check_fn);
   }
 };
 
-template <typename Device>
+template <typename Device, typename T>
 class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
  public:
   explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
@@ -353,12 +358,12 @@
  protected:
   void DoComputeAndPostProcess(OpKernelContext* context) override {
     auto suppress_check_fn =
-        CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+        CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
     int num_valid_outputs;
 
-    DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
-                          score_threshold_val_, suppress_check_fn,
-                          pad_to_max_output_size_, &num_valid_outputs);
+    DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+                             score_threshold_val_, suppress_check_fn,
+                             pad_to_max_output_size_, &num_valid_outputs);
 
     // Allocate scalar output tensor for number of indices computed.
     Tensor* num_outputs_t = nullptr;
@@ -413,22 +418,37 @@
     auto suppress_check_fn =
         CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
 
-    DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
-                          score_threshold_val, suppress_check_fn);
+    DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+                                 score_threshold_val, suppress_check_fn);
   }
 };
 
 REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
                         NonMaxSuppressionOp<CPUDevice>);
 
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
-                        NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+    Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
+    NonMaxSuppressionV2Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
+                            .TypeConstraint<Eigen::half>("T")
+                            .Device(DEVICE_CPU),
+                        NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
 
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
-                        NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+    Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
+    NonMaxSuppressionV3Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
+                            .TypeConstraint<Eigen::half>("T")
+                            .Device(DEVICE_CPU),
+                        NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
 
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
-                        NonMaxSuppressionV4Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+    Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
+    NonMaxSuppressionV4Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
+                            .TypeConstraint<Eigen::half>("T")
+                            .Device(DEVICE_CPU),
+                        NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
 
 REGISTER_KERNEL_BUILDER(
     Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 876a170..fc1c900 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/common_runtime/placer.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/function.h"
@@ -104,13 +105,6 @@
         for (auto d : lib->device_mgr()->ListDevices()) {
           device_set.AddDevice(d);
         }
-        Placer placer(graph.get(), &device_set);
-        OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
-        std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
-        OP_REQUIRES_OK_ASYNC(
-            ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
-            done);
 
         // The FunctionLibraryRuntime's library cannot be mutated from within
         // an OpKernel, so functions are instantiated in an overlay library.
@@ -124,6 +118,47 @@
             new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
         overlay_libs_.emplace(lib, overlay_lib);
 
+        GraphOptimizationPassOptions optimization_options;
+        // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+        // make it possible to specify the relevant options via attributes.
+        SessionOptions session_options;
+        session_options.env = ctx->env();
+        optimization_options.session_options = &session_options;
+        optimization_options.graph = &graph;
+        optimization_options.flib_def = overlay_lib;
+        optimization_options.device_set = &device_set;
+        OP_REQUIRES_OK_ASYNC(
+            ctx,
+            OptimizationPassRegistry::Global()->RunGrouping(
+                OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+            done);
+        Placer placer(graph.get(), &device_set);
+        OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+        OP_REQUIRES_OK_ASYNC(
+            ctx,
+            OptimizationPassRegistry::Global()->RunGrouping(
+                OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+            done);
+        OP_REQUIRES_OK_ASYNC(
+            ctx,
+            OptimizationPassRegistry::Global()->RunGrouping(
+                OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+                optimization_options),
+            done);
+
+        std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+        OP_REQUIRES_OK_ASYNC(
+            ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+            done);
+        optimization_options.graph = nullptr;
+        optimization_options.device_set = nullptr;
+        optimization_options.partition_graphs = &subgraphs;
+        OP_REQUIRES_OK_ASYNC(ctx,
+                             OptimizationPassRegistry::Global()->RunGrouping(
+                                 OptimizationPassRegistry::POST_PARTITIONING,
+                                 optimization_options),
+                             done);
+
         auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
         for (const auto& pair : subgraphs) {
           // TODO(akshayka): Fail gracefully if the set of devices corresponds
@@ -175,7 +210,7 @@
         TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
         DataType dtype = attr_value->type();
         if (dtype == DT_RESOURCE) {
-          ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+          const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
           node->set_assigned_device_name(handle.device());
         }
       }
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d4042..97ddc85 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@
   }
 
   void Compute(OpKernelContext* context) override {
-    ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+    const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
     handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
     handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
     context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc
index 5318d8c..e4ca89e 100644
--- a/tensorflow/core/kernels/reduction_ops_sum.cc
+++ b/tensorflow/core/kernels/reduction_ops_sum.cc
@@ -76,7 +76,15 @@
         .HostMemory("output")
         .HostMemory("reduction_indices"),
     ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>);
-
+REGISTER_KERNEL_BUILDER(
+    Name("Sum")
+        .Device(DEVICE_GPU)
+        .TypeConstraint<int64>("T")
+        .TypeConstraint<int32>("Tidx")
+        .HostMemory("input")
+        .HostMemory("output")
+        .HostMemory("reduction_indices"),
+    ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>);
 #endif
 
 #ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c..7edaaad 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
 
 namespace tensorflow {
 
@@ -56,4 +57,36 @@
 REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
                         RegexFullMatchOp);
 
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+  explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    string pattern;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+    re_ = MakeUnique<RE2>(pattern);
+    OP_REQUIRES(ctx, re_->ok(),
+                errors::InvalidArgument("Invalid pattern: ", pattern,
+                                        ", error: ", re_->error()));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* input_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+    const auto& input_flat = input_tensor->flat<string>();
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+                                             &output_tensor));
+    auto output_flat = output_tensor->flat<bool>();
+    for (size_t i = 0; i < input_flat.size(); ++i) {
+      output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+    }
+  }
+
+ private:
+  std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+                        StaticRegexFullMatchOp);
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 194a711..26f107f 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -47,7 +47,7 @@
   std::unordered_set<string> retval;
   for (const string& node_name_and_port : node_names_and_ports) {
     const TensorId tid = ParseTensorName(node_name_and_port);
-    retval.emplace(std::string(tid.first));
+    retval.emplace(tid.first);
   }
   return retval;
 }
@@ -64,7 +64,7 @@
 const NodeDef* FindNodeDefByName(const string& input,
                                  const GraphDef& graph_def) {
   const TensorId tid = ParseTensorName(input);
-  const string name = std::string(tid.first);
+  const string name = string(tid.first);
   for (const NodeDef& node_def : graph_def.node()) {
     if (node_def.name() == name) {
       return &node_def;
@@ -423,7 +423,7 @@
   std::vector<DataType> data_types;
   std::vector<TensorShape> shapes;
   const TensorId tid = ParseTensorName(name_and_port);
-  const string node_name = std::string(tid.first);
+  const string node_name(tid.first);
   const int port = tid.second;
   const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
   CHECK_NOTNULL(node_def);
@@ -522,8 +522,7 @@
     const TensorShapeMap& tensor_shape_map, const string& node_name) {
   if (node_name.find(':') != string::npos) {
     const TensorId tid = ParseTensorName(node_name);
-    return GetTensorShapeType(tensor_shape_map, std::string(tid.first),
-                              tid.second);
+    return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
   } else {
     return GetTensorShapeType(tensor_shape_map, node_name, 0);
   }
@@ -570,7 +569,7 @@
   const TensorId tid = ParseTensorName(name);
   CHECK_EQ(tensor_shape_map->count(name), 0);
   tensor_shape_map->emplace(
-      std::string(tid.first),
+      string(tid.first),
       std::make_pair(tid.second,
                      std::make_pair(tensor.dtype(), tensor.shape())));
 }
@@ -692,7 +691,7 @@
   std::vector<NodeBuilder::NodeOut> node_out_list;
   for (const string& input : inputs) {
     const TensorId tid = ParseTensorName(input);
-    Node* node = FindMutableNodeByName(std::string(tid.first), graph);
+    Node* node = FindMutableNodeByName(string(tid.first), graph);
     CHECK_NOTNULL(node);
     node_out_list.emplace_back(node, tid.second);
   }
@@ -848,7 +847,7 @@
 
   for (const string& subgraph_input : std::get<1>(cluster)) {
     const TensorId tid = ParseTensorName(subgraph_input);
-    const string subgraph_input_name = std::string(tid.first);
+    const string subgraph_input_name(tid.first);
     const int subgraph_input_port = tid.second;
     const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
     CHECK_NOTNULL(node_def);
@@ -895,7 +894,7 @@
   std::deque<const Node*> queue;
   for (const string& output : border_outputs) {
     const TensorId tid = ParseTensorName(output);
-    const string& output_node_name = std::string(tid.first);
+    const string output_node_name(tid.first);
     for (const Node* node : graph.nodes()) {
       if (output_node_name == node->name()) {
         queue.push_back(node);
@@ -975,7 +974,7 @@
       for (int j = 0; j < border_outputs.size(); ++j) {
         const string& output = border_outputs.at(j);
         const TensorId tid = ParseTensorName(output);
-        const string output_name = std::string(tid.first);
+        const string output_name(tid.first);
         Node* src_node = edge->src();
         if (src_node != nullptr && src_node->name() == output_name &&
             edge->src_output() == tid.second) {
@@ -995,12 +994,11 @@
   // RemoteFusedGraphExecuteOpNode
   for (const string& output : outputs) {
     const TensorId output_tid = ParseTensorName(output);
-    const string output_name = std::string(output_tid.first);
+    const string output_name(output_tid.first);
     for (size_t i = 0; i < border_outputs.size(); ++i) {
       const TensorId subgraph_output_tid =
           ParseTensorName(border_outputs.at(i));
-      const string& subgraph_output_name =
-          std::string(subgraph_output_tid.first);
+      const string subgraph_output_name(subgraph_output_tid.first);
       if (output_name == subgraph_output_name) {
         LOG(INFO) << "As graph output and subgraph output are same, "
                   << "the graph output node is replaced by identity node";
@@ -1435,7 +1433,7 @@
     GraphDef* graph_def) {
   const TensorId tid = ParseTensorName(input);
   CHECK_EQ(0, tid.second);
-  const string node_name = std::string(tid.first);
+  const string node_name(tid.first);
   for (NodeDef& node : *graph_def->mutable_node()) {
     if (node.name() != node_name) {
       continue;
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index ebcfb67..26705a8 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -79,7 +79,7 @@
 
 void ReadVariableOp::Compute(OpKernelContext* ctx) {
   Var* variable = nullptr;
-  ResourceHandle handle = HandleFromInput(ctx, 0);
+  const ResourceHandle& handle = HandleFromInput(ctx, 0);
   const auto status = LookupResource(ctx, handle, &variable);
   OP_REQUIRES(ctx, status.ok(),
               errors::FailedPrecondition(
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a..cded417 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@
   OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
               errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                       "), ", "(", seq_lens.NumElements(),
-                                      " vs. ", input.dim_size(batch_dim)));
+                                      " vs. ", input.dim_size(batch_dim), ")"));
 
   for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
     OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@
   OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
               errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
                                       "), ", "(", seq_lens.NumElements(),
-                                      " vs. ", input.dim_size(batch_dim)));
+                                      " vs. ", input.dim_size(batch_dim), ")"));
 }
 
 template <>
@@ -127,6 +127,7 @@
     auto seq_lens_t = seq_lens.vec<Tlen>();
 
     CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+    if (!context->status().ok()) return;
 
     const int input_dims = input.dims();
 
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index e335e38..82546d5 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -161,9 +161,12 @@
   // If we cannot find a cached reader we will allocate our own.
   std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
 
-  const checkpoint::TensorSliceReader* reader =
-      context->slice_reader_cache()->GetReader(file_pattern, open_func,
-                                               preferred_shard);
+  const checkpoint::TensorSliceReader* reader = nullptr;
+
+  if (context->slice_reader_cache()) {
+    reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
+                                                      preferred_shard);
+  }
   if (!reader) {
     allocated_reader.reset(new checkpoint::TensorSliceReader(
         file_pattern, open_func, preferred_shard));
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ab4de6c..180eb3c 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -220,9 +220,9 @@
         context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
 
     if (delete_old_dirs_) {
-      const string& merged_dir = std::string(io::Dirname(merged_prefix));
+      const string merged_dir(io::Dirname(merged_prefix));
       for (const string& input_prefix : input_prefixes) {
-        const string& dirname = std::string(io::Dirname(input_prefix));
+        const string dirname(io::Dirname(input_prefix));
         if (dirname == merged_dir) continue;
         Status status = env->DeleteDir(dirname);
         // For sharded save, only the first delete will go through and all
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590a..30cb1e0 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -60,8 +61,7 @@
 
 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
 
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
-                                      GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
 
 static void ExpectHasError(const Status& s, StringPiece substr) {
   EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@
     Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
     EXPECT_FALSE(s.ok());
     ExpectHasError(
-        s,
-        "No unary variant shape function found for Variant type_name: "
-        "NO KNOWN SHAPE");
+        s, strings::StrCat(
+               "No unary variant shape function found for Variant type_index: ",
+               port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
   }
 
   {
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 11149c4..a4453bd 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -50,10 +50,10 @@
  public:
   SparseConditionalAccumulator(const DataType& dtype,
                                const PartialTensorShape& shape,
-                               const string& name)
+                               const string& name, const string& reduction_type)
       : TypedConditionalAccumulatorBase<
             std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
-            dtype, shape, name) {
+            dtype, shape, name, reduction_type) {
     accum_idx_vec_ = nullptr;
     count_element_ = nullptr;
     accum_val_ = nullptr;
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1..1e542a2 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@
   Creator GetCreator() const override {
     return [this](ConditionalAccumulatorBase** ret) {
       SparseConditionalAccumulator<Device, T>* accumulator =
-          new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
-                                                      cinfo_.name());
+          new SparseConditionalAccumulator<Device, T>(
+              dtype_, shape_, cinfo_.name(), reduction_type_);
       *ret = accumulator;
       return Status::OK();
     };
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c53..11db72b 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@
   void ComputeEasyCases(OpKernelContext* context, bool* done) {
     const Tensor& input = context->input(1);
     const TensorShape& input_shape = input.shape();
-    const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+    const Tensor& split_dim_tensor = context->input(0);
+    OP_REQUIRES(
+        context, split_dim_tensor.shape().dims() == 0,
+        errors::InvalidArgument("split_dim must be a scalar but has rank ",
+                                split_dim_tensor.shape().dims()));
+    const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
     const int32 split_dim =
         split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
     const int32 num_split = num_outputs();
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f6..add4afa 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@
 };
 
 Status GetStack(OpKernelContext* ctx, Stack** stack) {
-  string key;
   if (ctx->input_dtype(0) == DT_RESOURCE) {
-    auto resource = ctx->input(0).flat<ResourceHandle>()(0);
-    key = resource.name();
+    return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
   } else {
     Tensor Tstack_handle = ctx->mutable_input(0, false);
     if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@
     }
     const string& container = Tstack_handle.flat<string>()(0);
     const string& stack_name = Tstack_handle.flat<string>()(1);
-    key = strings::StrCat(container, stack_name);
+    string key = strings::StrCat(container, stack_name);
+    ResourceMgr* rm = ctx->resource_manager();
+    if (rm == nullptr) {
+      return errors::Internal("No resource manager.");
+    }
+    auto* step_container = ctx->step_container();
+    if (step_container == nullptr) {
+      return errors::Internal("No step container.");
+    }
+    TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+    return Status::OK();
   }
-  ResourceMgr* rm = ctx->resource_manager();
-  if (rm == nullptr) {
-    return errors::Internal("No resource manager.");
-  }
-  auto* step_container = ctx->step_container();
-  if (step_container == nullptr) {
-    return errors::Internal("No step container.");
-  }
-  TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
-  return Status::OK();
 }
 
 std::atomic<int64> Stack::stack_counter{0};
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 2aeafa2..544dca9 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -43,7 +43,7 @@
     for (int64 i = 0; i < input.size(); ++i) {
       StringPiece entry(input(i));
       str_util::RemoveWhitespaceContext(&entry);
-      output(i) = std::string(entry);
+      output(i) = string(entry);
     }
   }
 };
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e4591..07f1d6e 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@
 limitations under the License.
 ==============================================================================*/
 
+#include <cstddef>
+#include <cstdlib>
 #include <string>
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/bcast.h"
 
 namespace tensorflow {
@@ -64,26 +68,28 @@
         const T len =
             tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
-          string in = input(i);
+          StringPiece in(input(i));
           OP_REQUIRES(
-              context, FastBoundsCheck(pos, in.size() + 1),
+              context, FastBoundsCheck(std::abs(pos), in.size() + 1),
               errors::InvalidArgument("pos ", pos, " out of range for string",
                                       "b'", in, "' at index ", i));
-          output(i) = in.substr(pos, len);
+          StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+          output(i).assign(sub_in.data(), sub_in.size());
         }
       } else {
         // Perform Op element-wise with tensor pos/len
         auto pos_flat = pos_tensor.flat<T>();
         auto len_flat = len_tensor.flat<T>();
         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
-          string in = input(i);
+          StringPiece in(input(i));
           const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
           const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
           OP_REQUIRES(
-              context, FastBoundsCheck(pos, in.size() + 1),
+              context, FastBoundsCheck(std::abs(pos), in.size() + 1),
               errors::InvalidArgument("pos ", pos, " out of range for string",
                                       "b'", in, "' at index ", i));
-          output(i) = in.substr(pos, len);
+          StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+          output(i).assign(sub_in.data(), sub_in.size());
         }
       }
     } else {
@@ -142,14 +148,16 @@
 
           // Iterate through broadcasted tensors and perform substr
           for (int i = 0; i < output_shape.dim_size(0); ++i) {
-            string in = input_bcast(i);
+            StringPiece in(input_bcast(i));
             const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
             const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
             OP_REQUIRES(
-                context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+                context,
+                FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
                 errors::InvalidArgument("pos ", pos, " out of range for string",
                                         "b'", in, "' at index ", i));
-            output(i) = in.substr(pos, len);
+            StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+            output(i).assign(sub_in.data(), sub_in.size());
           }
           break;
         }
@@ -192,16 +200,18 @@
           // Iterate through broadcasted tensors and perform substr
           for (int i = 0; i < output_shape.dim_size(0); ++i) {
             for (int j = 0; j < output_shape.dim_size(1); ++j) {
-              string in = input_bcast(i, j);
+              StringPiece in(input_bcast(i, j));
               const T pos =
                   tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
               const T len =
                   tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
-              OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
-                          errors::InvalidArgument(
-                              "pos ", pos, " out of range for ", "string b'",
-                              in, "' at index (", i, ", ", j, ")"));
-              output(i, j) = in.substr(pos, len);
+              OP_REQUIRES(
+                  context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+                  errors::InvalidArgument("pos ", pos, " out of range for ",
+                                          "string b'", in, "' at index (", i,
+                                          ", ", j, ")"));
+              StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+              output(i, j).assign(sub_in.data(), sub_in.size());
             }
           }
           break;
@@ -213,6 +223,16 @@
       }
     }
   }
+
+ private:
+  // This adjusts the requested position. Note it does not perform any bounds
+  // checks.
+  T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+    if (pos_requested < 0) {
+      return s.size() + pos_requested;
+    }
+    return pos_requested;
+  }
 };
 
 #define REGISTER_SUBSTR(type)                                      \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000..2e07050
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+    "**TensorFlow** is an open source software library for numerical "
+    "computation using data flow graphs.",
+    "The graph nodes represent mathematical operations, while the graph edges "
+    "represent the multidimensional data arrays (tensors) that flow between "
+    "them.",
+    "This flexible architecture enables you to deploy computation to one or "
+    "more CPUs or GPUs in a desktop, server, or mobile device without "
+    "rewriting code.",
+    "TensorFlow also includes "
+    "[TensorBoard](https://www.tensorflow.org/guide/"
+    "summaries_and_tensorboard), a data visualization toolkit.",
+    "TensorFlow was originally developed by researchers and engineers working "
+    "on the Google Brain team within Google's Machine Intelligence Research "
+    "organization for the purposes of conducting machine learning and deep "
+    "neural networks research.",
+    "The system is general enough to be applicable in a wide variety of other "
+    "domains, as well.",
+    "TensorFlow provides stable Python API and C APIs as well as without API "
+    "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+    "Swift."};
+
+Tensor GetTestTensor(int batch) {
+  const int sz = TF_ARRAYSIZE(lines);
+  Tensor t(DT_STRING, {batch});
+  auto s = t.flat<string>();
+  for (int i = 0; i < batch; ++i) {
+    s(i) = lines[i % sz];
+  }
+  return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor position(DT_INT32, TensorShape({}));
+  position.flat<int32>().setConstant(pos);
+  Tensor length(DT_INT32, TensorShape({}));
+  length.flat<int32>().setConstant(len);
+
+  TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+                  .Input(test::graph::Constant(g, input))
+                  .Input(test::graph::Constant(g, position))
+                  .Input(test::graph::Constant(g, length))
+                  .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+void BM_Substr(int iters, int batch_size) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters));
+  testing::UseRealTime();
+  Tensor input = GetTestTensor(batch_size);
+  Graph* g = SetupSubstrGraph(input, 3, 30);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
+    256);
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 632b65e..fe93b91 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -290,14 +290,14 @@
       }
     } else {
       container = "_tensor_arrays";
-      auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+      const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
       if (StringPiece(resource.name()).substr(0, container.size()) !=
           container) {
         return errors::InvalidArgument("Wrong input container. ",
                                        resource.name());
       }
       tensor_array_name =
-          std::string(StringPiece(resource.name()).substr(container.size()));
+          string(StringPiece(resource.name()).substr(container.size()));
     }
 
     auto output_handle = tensor_array_output_handle->flat<string>();
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 9dedb61..ca341e5 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -35,8 +35,9 @@
  public:
   TypedConditionalAccumulatorBase(const DataType& dtype,
                                   const PartialTensorShape& shape,
-                                  const string& name)
-      : ConditionalAccumulatorBase(dtype, shape, name) {}
+                                  const string& name,
+                                  const string& reduction_type)
+      : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
 
   /**
    * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index ed2bf3e..1bf46b5 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -134,7 +134,7 @@
                     "Contents tensor must be scalar, but had shape: ",
                     contents_input->shape().DebugString()));
     const string& filename = filename_input->scalar<string>()();
-    const string dir = std::string(io::Dirname(filename));
+    const string dir(io::Dirname(filename));
     if (!context->env()->FileExists(dir).ok()) {
       OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
     }
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index 982901a..d5cbe6c 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -136,11 +136,9 @@
         ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
       });
 }
-// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and
-// `^^key:value^^` in a follow-on CL.
 // LINT.IfChange
 inline string FormatColocationNodeForError(const string& name) {
-  return strings::StrCat("^^colocation_node:", name, "^^");
+  return strings::StrCat("{{colocation_node ", name, "}}");
 }
 // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
 template <typename T>
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index c18dc9a..2d622dc 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -13,674 +13,19 @@
 limitations under the License.
 ==============================================================================*/
 
-// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
-// for sequences of length <= N are provided inline without requiring
-// any heap allocation.  Typically N is very small (e.g., 4) so that
-// sequences that are expected to be short do not require allocations.
-//
-// Only some of the std::vector<> operations are currently implemented.
-// Other operations may be added as needed to facilitate migrating
-// code that uses std::vector<> to InlinedVector<>.
-//
-// NOTE: If you want an inlined version to replace use of a
-// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
-// in util/bitmap/inlined_bitvector.h
-//
-// TODO(billydonahue): change size_t to size_type where appropriate.
-
 #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
 #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
 
-#include <stddef.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/byte_order.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
+#include "absl/container/inlined_vector.h"
+// TODO(kramerb): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
+#include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
-#include <initializer_list>  // NOLINT(build/include_order)
-
 namespace tensorflow {
 namespace gtl {
 
-template <typename T, int N>
-class InlinedVector {
- public:
-  typedef T value_type;
-  typedef T* pointer;
-  typedef const T* const_pointer;
-  typedef T& reference;
-  typedef const T& const_reference;
-  typedef size_t size_type;
-  typedef std::ptrdiff_t difference_type;
-  typedef pointer iterator;
-  typedef const_pointer const_iterator;
-
-  // Create an empty vector
-  InlinedVector();
-
-  // Create a vector with n copies of value_type().
-  explicit InlinedVector(size_t n);
-
-  // Create a vector with n copies of elem
-  InlinedVector(size_t n, const value_type& elem);
-
-  // Create and initialize with the elements [range_start .. range_end).
-  // The unused enable_if argument restricts this constructor so that it is
-  // elided when value_type is an integral type.  This prevents ambiguous
-  // interpretation between a call to this constructor with two integral
-  // arguments and a call to the preceding (n, elem) constructor.
-  template <typename InputIterator>
-  InlinedVector(
-      InputIterator range_start, InputIterator range_end,
-      typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
-          NULL) {
-    InitRep();
-    AppendRange(range_start, range_end);
-  }
-
-  InlinedVector(std::initializer_list<value_type> init) {
-    InitRep();
-    AppendRange(init.begin(), init.end());
-  }
-
-  InlinedVector(const InlinedVector& v);
-
-  ~InlinedVector() { clear(); }
-
-  InlinedVector& operator=(const InlinedVector& v) {
-    // Optimized to avoid reallocation.
-    // Prefer reassignment to copy construction for elements.
-    const size_t s = size();
-    const size_t vs = v.size();
-    if (s < vs) {  // grow
-      reserve(vs);
-      if (s) std::copy(v.begin(), v.begin() + s, begin());
-      std::copy(v.begin() + s, v.end(), std::back_inserter(*this));
-    } else {  // maybe shrink
-      erase(begin() + vs, end());
-      std::copy(v.begin(), v.end(), begin());
-    }
-    return *this;
-  }
-
-  size_t size() const { return size_internal(); }
-
-  bool empty() const { return (size() == 0); }
-
-  // Return number of elements that can be stored in vector
-  // without requiring a reallocation of underlying memory
-  size_t capacity() const {
-    if (is_inline()) {
-      return kFit;
-    } else {
-      return static_cast<size_t>(1) << u_.data[kSize - 2];
-    }
-  }
-
-  // Return a pointer to the underlying array.
-  // Only result[0,size()-1] are defined.
-  pointer data() {
-    if (is_inline()) {
-      return reinterpret_cast<T*>(u_.data);
-    } else {
-      return outofline_pointer();
-    }
-  }
-  const_pointer data() const {
-    return const_cast<InlinedVector<T, N>*>(this)->data();
-  }
-
-  // Remove all elements
-  void clear() {
-    DiscardStorage();
-    u_.data[kSize - 1] = 0;
-  }
-
-  // Return the ith element
-  // REQUIRES: 0 <= i < size()
-  const value_type& at(size_t i) const {
-    DCHECK_LT(i, size());
-    return data()[i];
-  }
-  const value_type& operator[](size_t i) const {
-    DCHECK_LT(i, size());
-    return data()[i];
-  }
-
-  // Return a non-const reference to the ith element
-  // REQUIRES: 0 <= i < size()
-  value_type& at(size_t i) {
-    DCHECK_LT(i, size());
-    return data()[i];
-  }
-  value_type& operator[](size_t i) {
-    DCHECK_LT(i, size());
-    return data()[i];
-  }
-
-  value_type& back() {
-    DCHECK(!empty());
-    return at(size() - 1);
-  }
-
-  const value_type& back() const {
-    DCHECK(!empty());
-    return at(size() - 1);
-  }
-
-  value_type& front() {
-    DCHECK(!empty());
-    return at(0);
-  }
-
-  const value_type& front() const {
-    DCHECK(!empty());
-    return at(0);
-  }
-
-  // Append a T constructed with args to the vector.
-  // Increases size() by one.
-  // Amortized complexity: O(1)
-  // Worst-case complexity: O(size())
-  template <typename... Args>
-  void emplace_back(Args&&... args) {
-    size_t s = size();
-    DCHECK_LE(s, capacity());
-    if (s < capacity()) {
-      new (data() + s) T(std::forward<Args>(args)...);
-      set_size_internal(s + 1);
-    } else {
-      EmplaceBackSlow(std::forward<Args>(args)...);
-    }
-  }
-
-  // Append t to the vector.
-  // Increases size() by one.
-  // Amortized complexity: O(1)
-  // Worst-case complexity: O(size())
-  void push_back(const value_type& t) { emplace_back(t); }
-  void push_back(value_type&& t) { emplace_back(std::move(t)); }
-
-  inline void pop_back() {
-    DCHECK(!empty());
-    const size_t s = size();
-    Destroy(data() + s - 1, 1);
-    set_size_internal(s - 1);
-  }
-
-  // Resizes the vector to contain "n" elements.
-  // If "n" is smaller than the initial size, extra elements are destroyed.
-  // If "n" is larger than the initial size, enough copies of "elem"
-  // are appended to increase the size to "n". If "elem" is omitted,
-  // new elements are value-initialized.
-  void resize(size_t n) { Resize<ValueInit>(n, nullptr); }
-  void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); }
-
-  iterator begin() { return data(); }
-  const_iterator begin() const { return data(); }
-
-  iterator end() { return data() + size(); }
-  const_iterator end() const { return data() + size(); }
-
-  iterator insert(iterator pos, const value_type& v);
-
-  iterator erase(iterator pos) {
-    DCHECK_LT(pos, end());
-    DCHECK_GE(pos, begin());
-    std::copy(pos + 1, end(), pos);
-    pop_back();
-    return pos;
-  }
-
-  iterator erase(iterator first, iterator last);
-
-  // Enlarges the underlying representation so it can hold at least
-  // "n" elements without reallocation.
-  // Does not change size() or the actual contents of the vector.
-  void reserve(size_t n) {
-    if (n > capacity()) {
-      // Make room for new elements
-      Grow<Move>(n);
-    }
-  }
-
-  // Swap the contents of *this with other.
-  // REQUIRES: value_type is swappable and copyable.
-  void swap(InlinedVector& other);
-
- private:
-  // Representation can either be inlined or out-of-line.
-  // In either case, at least sizeof(void*) + 8 bytes are available.
-  //
-  // Inlined:
-  //   Last byte holds the length.
-  //   First (length*sizeof(T)) bytes stores the elements.
-  // Outlined:
-  //   Last byte holds kSentinel.
-  //   Second-last byte holds lg(capacity)
-  //   Preceding 6 bytes hold size.
-  //   First sizeof(T*) bytes hold pointer.
-
-  // Compute rep size.
-  static const size_t kSizeUnaligned = N * sizeof(T) + 1;  // Room for tag
-  static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16;  // Align
-
-  // See how many fit T we can fit inside kSize, but no more than 254
-  // since 255 is used as sentinel tag for out-of-line allocation.
-  static const unsigned int kSentinel = 255;
-  static const size_t kFit1 = (kSize - 1) / sizeof(T);
-  static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1;
-
-  union {
-    unsigned char data[kSize];
-    // Force data to be aligned enough for a pointer.
-    T* unused_aligner;
-  } u_;
-
-  inline void InitRep() { u_.data[kSize - 1] = 0; }
-  inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; }
-
-  inline T* outofline_pointer() const {
-    T* ptr;
-    memcpy(&ptr, &u_.data[0], sizeof(ptr));
-    return ptr;
-  }
-
-  inline void set_outofline_pointer(T* p) {
-    memcpy(&u_.data[0], &p, sizeof(p));
-  }
-
-  inline uint64_t outofline_word() const {
-    uint64_t word;
-    memcpy(&word, &u_.data[kSize - 8], sizeof(word));
-    return word;
-  }
-
-  inline void set_outofline_word(uint64_t w) {
-    memcpy(&u_.data[kSize - 8], &w, sizeof(w));
-  }
-
-  inline size_t size_internal() const {
-    uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]);
-    if (s != kSentinel) {
-      return static_cast<size_t>(s);
-    } else {
-      const uint64_t word = outofline_word();
-      if (port::kLittleEndian) {
-        // The sentinel and capacity bits are most-significant bits in word.
-        return static_cast<size_t>(word & 0xffffffffffffull);
-      } else {
-        // The sentinel and capacity bits are least-significant bits in word.
-        return static_cast<size_t>(word >> 16);
-      }
-    }
-  }
-
-  void set_size_internal(size_t n) {
-    if (is_inline()) {
-      DCHECK_LT(n, kSentinel);
-      u_.data[kSize - 1] = static_cast<unsigned char>(n);
-    } else {
-      uint64_t word;
-      if (port::kLittleEndian) {
-        // The sentinel and capacity bits are most-significant bits in word.
-        word = (static_cast<uint64_t>(n) |
-                (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) |
-                (static_cast<uint64_t>(kSentinel) << 56));
-      } else {
-        // The sentinel and capacity bits are least-significant bits in word.
-        word = ((static_cast<uint64_t>(n) << 16) |
-                (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) |
-                (static_cast<uint64_t>(kSentinel)));
-      }
-      set_outofline_word(word);
-      DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n;
-    }
-  }
-
-  void DiscardStorage() {
-    T* base = data();
-    size_t n = size();
-    Destroy(base, n);
-    if (!is_inline()) {
-      port::Free(base);
-    }
-  }
-
-  template <typename... Args>
-  void EmplaceBackSlow(Args&&... args) {
-    const size_t s = size();
-    DCHECK_EQ(s, capacity());
-    Grow<Move, Construct>(s + 1, std::forward<Args>(args)...);
-    set_size_internal(s + 1);
-  }
-
-  // Movers for Grow
-  // Does nothing.
-  static void Nop(T* src, size_t n, T* dst) {}
-
-  // Moves srcs[0,n-1] contents to dst[0,n-1].
-  static void Move(T* src, size_t n, T* dst) {
-    for (size_t i = 0; i < n; i++) {
-      new (dst + i) T(std::move(*(src + i)));
-    }
-  }
-
-  // Initializers for Resize.
-  // Initializes dst[0,n-1] with empty constructor.
-  static void ValueInit(const T*, size_t n, T* dst) {
-    for (size_t i = 0; i < n; i++) {
-      new (dst + i) T();
-    }
-  }
-
-  // Initializes dst[0,n-1] with copies of *src.
-  static void Fill(const T* src, size_t n, T* dst) {
-    for (size_t i = 0; i < n; i++) {
-      new (dst + i) T(*src);
-    }
-  }
-
-  void Destroy(T* src, int n) {
-    if (!std::is_trivially_destructible<T>::value) {
-      for (int i = 0; i < n; i++) {
-        (src + i)->~T();
-      }
-    }
-  }
-
-  // Initialization methods for Grow.
-  // 1) Leave uninitialized memory.
-  struct Uninitialized {
-    void operator()(T*) const {}
-  };
-  // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
-  struct Construct {
-    template <class... Args>
-    void operator()(T* dst, Args&&... args) const {
-      new (dst) T(std::forward<Args>(args)...);
-    }
-  };
-
-  // Grow so that capacity >= n.  Uses Mover to move existing elements
-  // to new buffer, and possibly initialize the new element according
-  // to InitType.
-  // We pass the InitType and Mover as template arguments so that
-  // this code compiles even if T does not support copying or default
-  // construction.
-  template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized,
-            class... Args>
-  void Grow(size_t n, Args&&... args) {
-    size_t s = size();
-    DCHECK_LE(s, capacity());
-
-    // Compute new capacity by repeatedly doubling current capacity
-    size_t target = 1;
-    size_t target_lg = 0;
-    while (target < kFit || target < n) {
-      // TODO(psrc): Check and avoid overflow?
-      target_lg++;
-      target <<= 1;
-    }
-
-    T* src = data();
-    T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));
-
-    // Need to copy elem before discarding src since it might alias src.
-    InitType{}(dst + s, std::forward<Args>(args)...);
-    Mover(src, s, dst);
-    DiscardStorage();
-
-    u_.data[kSize - 1] = kSentinel;
-    u_.data[kSize - 2] = static_cast<unsigned char>(target_lg);
-    set_size_internal(s);
-    DCHECK_EQ(capacity(), target);
-    set_outofline_pointer(dst);
-  }
-
-  // Resize to size n.  Any new elements are initialized by passing
-  // elem and the destination to Initializer.  We pass the Initializer
-  // as a template argument so that this code compiles even if T does
-  // not support copying.
-  template <void(Initializer)(const T*, size_t, T*)>
-  void Resize(size_t n, const T* elem) {
-    size_t s = size();
-    if (n <= s) {
-      Destroy(data() + n, s - n);
-      set_size_internal(n);
-      return;
-    }
-    reserve(n);
-    DCHECK_GE(capacity(), n);
-    set_size_internal(n);
-    Initializer(elem, n - s, data() + s);
-  }
-
-  template <typename Iter>
-  void AppendRange(Iter first, Iter last, std::input_iterator_tag);
-
-  // Faster path for forward iterators.
-  template <typename Iter>
-  void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
-
-  template <typename Iter>
-  void AppendRange(Iter first, Iter last);
-};
-
-// Provide linkage for constants.
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSizeUnaligned;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSize;
-template <typename T, int N>
-const unsigned int InlinedVector<T, N>::kSentinel;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit1;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit;
-
-template <typename T, int N>
-inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) {
-  a.swap(b);
-}
-
-template <typename T, int N>
-inline bool operator==(const InlinedVector<T, N>& a,
-                       const InlinedVector<T, N>& b) {
-  return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
-}
-
-template <typename T, int N>
-inline bool operator!=(const InlinedVector<T, N>& a,
-                       const InlinedVector<T, N>& b) {
-  return !(a == b);
-}
-
-template <typename T, int N>
-inline bool operator<(const InlinedVector<T, N>& a,
-                      const InlinedVector<T, N>& b) {
-  return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
-}
-
-template <typename T, int N>
-inline bool operator>(const InlinedVector<T, N>& a,
-                      const InlinedVector<T, N>& b) {
-  return b < a;
-}
-
-template <typename T, int N>
-inline bool operator<=(const InlinedVector<T, N>& a,
-                       const InlinedVector<T, N>& b) {
-  return !(b < a);
-}
-
-template <typename T, int N>
-inline bool operator>=(const InlinedVector<T, N>& a,
-                       const InlinedVector<T, N>& b) {
-  return !(a < b);
-}
-
-// ========================================
-// Implementation
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector() {
-  InitRep();
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n) {
-  InitRep();
-  if (n > capacity()) {
-    Grow<Nop>(n);  // Must use Nop in case T is not copyable
-  }
-  set_size_internal(n);
-  ValueInit(nullptr, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) {
-  InitRep();
-  if (n > capacity()) {
-    Grow<Nop>(n);  // Can use Nop since we know we have nothing to copy
-  }
-  set_size_internal(n);
-  Fill(&elem, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) {
-  InitRep();
-  *this = v;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert(
-    iterator pos, const value_type& v) {
-  DCHECK_GE(pos, begin());
-  DCHECK_LE(pos, end());
-  if (pos == end()) {
-    push_back(v);
-    return end() - 1;
-  }
-  size_t s = size();
-  size_t idx = std::distance(begin(), pos);
-  if (s == capacity()) {
-    Grow<Move>(s + 1);
-  }
-  CHECK_LT(s, capacity());
-  pos = begin() + idx;  // Reset 'pos' into a post-enlarge iterator.
-  Fill(data() + s - 1, 1, data() + s);  // data[s] = data[s-1]
-  std::copy_backward(pos, data() + s - 1, data() + s);
-  *pos = v;
-
-  set_size_internal(s + 1);
-  return pos;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase(
-    iterator first, iterator last) {
-  DCHECK_LE(begin(), first);
-  DCHECK_LE(first, last);
-  DCHECK_LE(last, end());
-
-  size_t s = size();
-  ptrdiff_t erase_gap = std::distance(first, last);
-  std::copy(last, data() + s, first);
-  Destroy(data() + s - erase_gap, erase_gap);
-  set_size_internal(s - erase_gap);
-  return first;
-}
-
-template <typename T, int N>
-void InlinedVector<T, N>::swap(InlinedVector& other) {
-  using std::swap;  // Augment ADL with std::swap.
-  if (&other == this) {
-    return;
-  }
-
-  InlinedVector* a = this;
-  InlinedVector* b = &other;
-
-  const bool a_inline = a->is_inline();
-  const bool b_inline = b->is_inline();
-
-  if (!a_inline && !b_inline) {
-    // Just swap the top-level representations.
-    T* aptr = a->outofline_pointer();
-    T* bptr = b->outofline_pointer();
-    a->set_outofline_pointer(bptr);
-    b->set_outofline_pointer(aptr);
-
-    uint64_t aword = a->outofline_word();
-    uint64_t bword = b->outofline_word();
-    a->set_outofline_word(bword);
-    b->set_outofline_word(aword);
-    return;
-  }
-
-  // Make a the larger of the two to reduce number of cases.
-  size_t a_size = a->size();
-  size_t b_size = b->size();
-  if (a->size() < b->size()) {
-    swap(a, b);
-    swap(a_size, b_size);
-  }
-  DCHECK_GE(a_size, b_size);
-
-  if (b->capacity() < a_size) {
-    b->Grow<Move>(a_size);
-  }
-
-  // One is inline and one is not.
-  // 'a' is larger. Swap the elements up to the smaller array size.
-  std::swap_ranges(a->data(), a->data() + b_size, b->data());
-  std::uninitialized_copy(a->data() + b_size, a->data() + a_size,
-                          b->data() + b_size);
-  Destroy(a->data() + b_size, a_size - b_size);
-  a->set_size_internal(b_size);
-  b->set_size_internal(a_size);
-  DCHECK_EQ(b->size(), a_size);
-  DCHECK_EQ(a->size(), b_size);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
-                                             std::input_iterator_tag) {
-  std::copy(first, last, std::back_inserter(*this));
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
-                                             std::forward_iterator_tag) {
-  typedef typename std::iterator_traits<Iter>::difference_type Length;
-  Length length = std::distance(first, last);
-  size_t s = size();
-  reserve(s + length);
-  std::uninitialized_copy_n(first, length, data() + s);
-  set_size_internal(s + length);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) {
-  typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
-  AppendRange(first, last, IterTag());
-}
+using absl::InlinedVector;
 
 }  // namespace gtl
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
deleted file mode 100644
index 2721885..0000000
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ /dev/null
@@ -1,898 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-
-#include <list>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
-
-// A type that counts number of live occurrences of the type
-static int64 instances = 0;
-class Instance {
- public:
-  int value_;
-  explicit Instance(int x) : value_(x) { instances++; }
-  Instance(const Instance& x) : value_(x.value_) { instances++; }
-  ~Instance() { instances--; }
-
-  friend inline void swap(Instance& a, Instance& b) {
-    using std::swap;
-    swap(a.value_, b.value_);
-  }
-
-  friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
-    return o << "[value:" << v.value_ << "]";
-  }
-};
-
-typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
-
-// A simple reference counted class to make sure that the proper elements are
-// destroyed in the erase(begin, end) test.
-class RefCounted {
- public:
-  RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
-
-  RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
-    VLOG(5) << "[RefCounted: copy"
-            << " from count @" << v.count_ << "]";
-    Ref();
-  }
-
-  ~RefCounted() {
-    Unref();
-    count_ = nullptr;
-  }
-
-  friend void swap(RefCounted& a, RefCounted& b) {
-    using std::swap;
-    swap(a.value_, b.value_);
-    swap(a.count_, b.count_);
-  }
-
-  RefCounted& operator=(RefCounted v) {
-    using std::swap;
-    swap(*this, v);
-    return *this;
-  }
-
-  void Ref() const {
-    CHECK(count_ != nullptr);
-    ++(*count_);
-    VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
-  }
-
-  void Unref() const {
-    --(*count_);
-    CHECK_GE(*count_, 0);
-    VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
-  }
-
-  int count() const { return *count_; }
-
-  friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
-    return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
-  }
-
-  int value_;
-  int* count_;
-};
-
-typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
-
-// A class with a vtable pointer
-class Dynamic {
- public:
-  virtual ~Dynamic() {}
-
-  friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
-    return o << "[Dynamic]";
-  }
-};
-
-typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
-
-// Append 0..len-1 to *v
-static void Fill(IntVec* v, int len, int offset = 0) {
-  for (int i = 0; i < len; i++) {
-    v->push_back(i + offset);
-  }
-}
-
-static IntVec Fill(int len, int offset = 0) {
-  IntVec v;
-  Fill(&v, len, offset);
-  return v;
-}
-
-TEST(IntVec, SimpleOps) {
-  for (int len = 0; len < 20; len++) {
-    IntVec v;
-    const IntVec& cv = v;  // const alias
-
-    Fill(&v, len);
-    EXPECT_EQ(len, v.size());
-    EXPECT_LE(len, v.capacity());
-
-    for (int i = 0; i < len; i++) {
-      EXPECT_EQ(i, v[i]);
-    }
-    EXPECT_EQ(v.begin(), v.data());
-    EXPECT_EQ(cv.begin(), cv.data());
-
-    int counter = 0;
-    for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
-      EXPECT_EQ(counter, *iter);
-      counter++;
-    }
-    EXPECT_EQ(counter, len);
-
-    counter = 0;
-    for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
-      EXPECT_EQ(counter, *iter);
-      counter++;
-    }
-    EXPECT_EQ(counter, len);
-
-    if (len > 0) {
-      EXPECT_EQ(0, v.front());
-      EXPECT_EQ(len - 1, v.back());
-      v.pop_back();
-      EXPECT_EQ(len - 1, v.size());
-      for (size_t i = 0; i < v.size(); ++i) {
-        EXPECT_EQ(i, v[i]);
-      }
-    }
-  }
-}
-
-TEST(IntVec, Erase) {
-  for (int len = 1; len < 20; len++) {
-    for (int i = 0; i < len; ++i) {
-      IntVec v;
-      Fill(&v, len);
-      v.erase(v.begin() + i);
-      EXPECT_EQ(len - 1, v.size());
-      for (int j = 0; j < i; ++j) {
-        EXPECT_EQ(j, v[j]);
-      }
-      for (int j = i; j < len - 1; ++j) {
-        EXPECT_EQ(j + 1, v[j]);
-      }
-    }
-  }
-}
-
-// At the end of this test loop, the elements between [erase_begin, erase_end)
-// should have reference counts == 0, and all others elements should have
-// reference counts == 1.
-TEST(RefCountedVec, EraseBeginEnd) {
-  for (int len = 1; len < 20; ++len) {
-    for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
-      for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
-        std::vector<int> counts(len, 0);
-        RefCountedVec v;
-        for (int i = 0; i < len; ++i) {
-          v.push_back(RefCounted(i, &counts[i]));
-        }
-
-        int erase_len = erase_end - erase_begin;
-
-        v.erase(v.begin() + erase_begin, v.begin() + erase_end);
-
-        EXPECT_EQ(len - erase_len, v.size());
-
-        // Check the elements before the first element erased.
-        for (int i = 0; i < erase_begin; ++i) {
-          EXPECT_EQ(i, v[i].value_);
-        }
-
-        // Check the elements after the first element erased.
-        for (size_t i = erase_begin; i < v.size(); ++i) {
-          EXPECT_EQ(i + erase_len, v[i].value_);
-        }
-
-        // Check that the elements at the beginning are preserved.
-        for (int i = 0; i < erase_begin; ++i) {
-          EXPECT_EQ(1, counts[i]);
-        }
-
-        // Check that the erased elements are destroyed
-        for (int i = erase_begin; i < erase_end; ++i) {
-          EXPECT_EQ(0, counts[i]);
-        }
-
-        // Check that the elements at the end are preserved.
-        for (int i = erase_end; i < len; ++i) {
-          EXPECT_EQ(1, counts[i]);
-        }
-      }
-    }
-  }
-}
-
-struct NoDefaultCtor {
-  explicit NoDefaultCtor(int) {}
-};
-struct NoCopy {
-  NoCopy() {}
-  NoCopy(const NoCopy&) = delete;
-};
-struct NoAssign {
-  NoAssign() {}
-  NoAssign& operator=(const NoAssign&) = delete;
-};
-struct MoveOnly {
-  MoveOnly() {}
-  MoveOnly(MoveOnly&&) = default;
-  MoveOnly& operator=(MoveOnly&&) = default;
-};
-TEST(InlinedVectorTest, NoDefaultCtor) {
-  tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
-  (void)v;
-}
-TEST(InlinedVectorTest, NoCopy) {
-  tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
-  (void)v;
-}
-TEST(InlinedVectorTest, NoAssign) {
-  tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
-  (void)v;
-}
-TEST(InlinedVectorTest, MoveOnly) {
-  gtl::InlinedVector<MoveOnly, 2> v;
-  v.push_back(MoveOnly{});
-  v.push_back(MoveOnly{});
-  v.push_back(MoveOnly{});
-}
-
-TEST(IntVec, Insert) {
-  for (int len = 0; len < 20; len++) {
-    for (int pos = 0; pos <= len; pos++) {
-      IntVec v;
-      Fill(&v, len);
-      v.insert(v.begin() + pos, 9999);
-      EXPECT_EQ(v.size(), len + 1);
-      for (int i = 0; i < pos; i++) {
-        EXPECT_EQ(v[i], i);
-      }
-      EXPECT_EQ(v[pos], 9999);
-      for (size_t i = pos + 1; i < v.size(); i++) {
-        EXPECT_EQ(v[i], i - 1);
-      }
-    }
-  }
-}
-
-TEST(RefCountedVec, InsertConstructorDestructor) {
-  // Make sure the proper construction/destruction happen during insert
-  // operations.
-  for (int len = 0; len < 20; len++) {
-    SCOPED_TRACE(len);
-    for (int pos = 0; pos <= len; pos++) {
-      SCOPED_TRACE(pos);
-      std::vector<int> counts(len, 0);
-      int inserted_count = 0;
-      RefCountedVec v;
-      for (int i = 0; i < len; ++i) {
-        SCOPED_TRACE(i);
-        v.push_back(RefCounted(i, &counts[i]));
-      }
-
-      for (auto elem : counts) {
-        EXPECT_EQ(1, elem);
-      }
-
-      RefCounted insert_element(9999, &inserted_count);
-      EXPECT_EQ(1, inserted_count);
-      v.insert(v.begin() + pos, insert_element);
-      EXPECT_EQ(2, inserted_count);
-      // Check that the elements at the end are preserved.
-      for (auto elem : counts) {
-        EXPECT_EQ(1, elem);
-      }
-      EXPECT_EQ(2, inserted_count);
-    }
-  }
-}
-
-TEST(IntVec, Resize) {
-  for (int len = 0; len < 20; len++) {
-    IntVec v;
-    Fill(&v, len);
-
-    // Try resizing up and down by k elements
-    static const int kResizeElem = 1000000;
-    for (int k = 0; k < 10; k++) {
-      // Enlarging resize
-      v.resize(len + k, kResizeElem);
-      EXPECT_EQ(len + k, v.size());
-      EXPECT_LE(len + k, v.capacity());
-      for (int i = 0; i < len + k; i++) {
-        if (i < len) {
-          EXPECT_EQ(i, v[i]);
-        } else {
-          EXPECT_EQ(kResizeElem, v[i]);
-        }
-      }
-
-      // Shrinking resize
-      v.resize(len, kResizeElem);
-      EXPECT_EQ(len, v.size());
-      EXPECT_LE(len, v.capacity());
-      for (int i = 0; i < len; i++) {
-        EXPECT_EQ(i, v[i]);
-      }
-    }
-  }
-}
-
-TEST(IntVec, InitWithLength) {
-  for (int len = 0; len < 20; len++) {
-    IntVec v(len, 7);
-    EXPECT_EQ(len, v.size());
-    EXPECT_LE(len, v.capacity());
-    for (int i = 0; i < len; i++) {
-      EXPECT_EQ(7, v[i]);
-    }
-  }
-}
-
-TEST(IntVec, CopyConstructorAndAssignment) {
-  for (int len = 0; len < 20; len++) {
-    IntVec v;
-    Fill(&v, len);
-    EXPECT_EQ(len, v.size());
-    EXPECT_LE(len, v.capacity());
-
-    IntVec v2(v);
-    EXPECT_EQ(v, v2);
-
-    for (int start_len = 0; start_len < 20; start_len++) {
-      IntVec v3;
-      Fill(&v3, start_len, 99);  // Add dummy elements that should go away
-      v3 = v;
-      EXPECT_EQ(v, v3);
-    }
-  }
-}
-
-TEST(OverheadTest, Storage) {
-  // Check for size overhead.
-  using tensorflow::gtl::InlinedVector;
-  EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>));
-  EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>));
-  EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>));
-  EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>));
-
-  EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>));
-  EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>));
-  EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>));
-  EXPECT_EQ(2 * sizeof(char*),
-            sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>));
-  EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>));
-}
-
-TEST(IntVec, Clear) {
-  for (int len = 0; len < 20; len++) {
-    SCOPED_TRACE(len);
-    IntVec v;
-    Fill(&v, len);
-    v.clear();
-    EXPECT_EQ(0, v.size());
-    EXPECT_EQ(v.begin(), v.end());
-  }
-}
-
-TEST(IntVec, Reserve) {
-  for (size_t len = 0; len < 20; len++) {
-    IntVec v;
-    Fill(&v, len);
-
-    for (size_t newlen = 0; newlen < 100; newlen++) {
-      const int* start_rep = v.data();
-      v.reserve(newlen);
-      const int* final_rep = v.data();
-      if (newlen <= len) {
-        EXPECT_EQ(start_rep, final_rep);
-      }
-      EXPECT_LE(newlen, v.capacity());
-
-      // Filling up to newlen should not change rep
-      while (v.size() < newlen) {
-        v.push_back(0);
-      }
-      EXPECT_EQ(final_rep, v.data());
-    }
-  }
-}
-
-template <typename T>
-static std::vector<typename T::value_type> Vec(const T& src) {
-  std::vector<typename T::value_type> result;
-  for (const auto& elem : src) {
-    result.push_back(elem);
-  }
-  return result;
-}
-
-TEST(IntVec, SelfRefPushBack) {
-  std::vector<string> std_v;
-  tensorflow::gtl::InlinedVector<string, 4> v;
-  const string s = "A quite long string to ensure heap.";
-  std_v.push_back(s);
-  v.push_back(s);
-  for (int i = 0; i < 20; ++i) {
-    EXPECT_EQ(std_v, Vec(v));
-
-    v.push_back(v.back());
-    std_v.push_back(std_v.back());
-  }
-  EXPECT_EQ(std_v, Vec(v));
-}
-
-TEST(IntVec, SelfRefPushBackWithMove) {
-  std::vector<string> std_v;
-  gtl::InlinedVector<string, 4> v;
-  const string s = "A quite long string to ensure heap.";
-  std_v.push_back(s);
-  v.push_back(s);
-  for (int i = 0; i < 20; ++i) {
-    EXPECT_EQ(v.back(), std_v.back());
-
-    v.push_back(std::move(v.back()));
-    std_v.push_back(std::move(std_v.back()));
-  }
-  EXPECT_EQ(v.back(), std_v.back());
-}
-
-TEST(IntVec, Swap) {
-  for (int l1 = 0; l1 < 20; l1++) {
-    SCOPED_TRACE(l1);
-    for (int l2 = 0; l2 < 20; l2++) {
-      SCOPED_TRACE(l2);
-      IntVec a = Fill(l1, 0);
-      IntVec b = Fill(l2, 100);
-      {
-        using std::swap;
-        swap(a, b);
-      }
-      EXPECT_EQ(l1, b.size());
-      EXPECT_EQ(l2, a.size());
-      for (int i = 0; i < l1; i++) {
-        SCOPED_TRACE(i);
-        EXPECT_EQ(i, b[i]);
-      }
-      for (int i = 0; i < l2; i++) {
-        SCOPED_TRACE(i);
-        EXPECT_EQ(100 + i, a[i]);
-      }
-    }
-  }
-}
-
-TEST(InstanceVec, Swap) {
-  for (int l1 = 0; l1 < 20; l1++) {
-    for (int l2 = 0; l2 < 20; l2++) {
-      InstanceVec a, b;
-      for (int i = 0; i < l1; i++) a.push_back(Instance(i));
-      for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
-      EXPECT_EQ(l1 + l2, instances);
-      {
-        using std::swap;
-        swap(a, b);
-      }
-      EXPECT_EQ(l1 + l2, instances);
-      EXPECT_EQ(l1, b.size());
-      EXPECT_EQ(l2, a.size());
-      for (int i = 0; i < l1; i++) {
-        EXPECT_EQ(i, b[i].value_);
-      }
-      for (int i = 0; i < l2; i++) {
-        EXPECT_EQ(100 + i, a[i].value_);
-      }
-    }
-  }
-}
-
-TEST(IntVec, EqualAndNotEqual) {
-  IntVec a, b;
-  EXPECT_TRUE(a == b);
-  EXPECT_FALSE(a != b);
-
-  a.push_back(3);
-  EXPECT_FALSE(a == b);
-  EXPECT_TRUE(a != b);
-
-  b.push_back(3);
-  EXPECT_TRUE(a == b);
-  EXPECT_FALSE(a != b);
-
-  b.push_back(7);
-  EXPECT_FALSE(a == b);
-  EXPECT_TRUE(a != b);
-
-  a.push_back(6);
-  EXPECT_FALSE(a == b);
-  EXPECT_TRUE(a != b);
-
-  a.clear();
-  b.clear();
-  for (int i = 0; i < 100; i++) {
-    a.push_back(i);
-    b.push_back(i);
-    EXPECT_TRUE(a == b);
-    EXPECT_FALSE(a != b);
-
-    b[i] = b[i] + 1;
-    EXPECT_FALSE(a == b);
-    EXPECT_TRUE(a != b);
-
-    b[i] = b[i] - 1;  // Back to before
-    EXPECT_TRUE(a == b);
-    EXPECT_FALSE(a != b);
-  }
-}
-
-TEST(IntVec, RelationalOps) {
-  IntVec a, b;
-  EXPECT_FALSE(a < b);
-  EXPECT_FALSE(b < a);
-  EXPECT_FALSE(a > b);
-  EXPECT_FALSE(b > a);
-  EXPECT_TRUE(a <= b);
-  EXPECT_TRUE(b <= a);
-  EXPECT_TRUE(a >= b);
-  EXPECT_TRUE(b >= a);
-  b.push_back(3);
-  EXPECT_TRUE(a < b);
-  EXPECT_FALSE(b < a);
-  EXPECT_FALSE(a > b);
-  EXPECT_TRUE(b > a);
-  EXPECT_TRUE(a <= b);
-  EXPECT_FALSE(b <= a);
-  EXPECT_FALSE(a >= b);
-  EXPECT_TRUE(b >= a);
-}
-
-TEST(InstanceVec, CountConstructorsDestructors) {
-  const int start = instances;
-  for (int len = 0; len < 20; len++) {
-    InstanceVec v;
-    for (int i = 0; i < len; i++) {
-      v.push_back(Instance(i));
-    }
-    EXPECT_EQ(start + len, instances);
-
-    {  // Copy constructor should create 'len' more instances.
-      InstanceVec v_copy(v);
-      EXPECT_EQ(start + len + len, instances);
-    }
-    EXPECT_EQ(start + len, instances);
-
-    // Enlarging resize() must construct some objects
-    v.resize(len + 10, Instance(100));
-    EXPECT_EQ(start + len + 10, instances);
-
-    // Shrinking resize() must destroy some objects
-    v.resize(len, Instance(100));
-    EXPECT_EQ(start + len, instances);
-
-    // reserve() must not increase the number of initialized objects
-    v.reserve(len + 1000);
-    EXPECT_EQ(start + len, instances);
-
-    // pop_back() and erase() must destroy one object
-    if (len > 0) {
-      v.pop_back();
-      EXPECT_EQ(start + len - 1, instances);
-      if (!v.empty()) {
-        v.erase(v.begin());
-        EXPECT_EQ(start + len - 2, instances);
-      }
-    }
-  }
-  EXPECT_EQ(start, instances);
-}
-
-TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
-  const int start = instances;
-  for (int len = 0; len < 20; len++) {
-    for (int longorshort = 0; longorshort <= 1; ++longorshort) {
-      InstanceVec longer, shorter;
-      for (int i = 0; i < len; i++) {
-        longer.push_back(Instance(i));
-        shorter.push_back(Instance(i));
-      }
-      longer.push_back(Instance(len));
-      EXPECT_EQ(start + len + len + 1, instances);
-
-      if (longorshort) {
-        shorter = longer;
-        EXPECT_EQ(start + (len + 1) + (len + 1), instances);
-      } else {
-        longer = shorter;
-        EXPECT_EQ(start + len + len, instances);
-      }
-    }
-  }
-  EXPECT_EQ(start, instances);
-}
-
-TEST(RangedConstructor, SimpleType) {
-  std::vector<int> source_v = {4, 5, 6, 7};
-  // First try to fit in inline backing
-  tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
-  tensorflow::gtl::InlinedVector<int, 4> empty4;
-  EXPECT_EQ(4, v.size());
-  EXPECT_EQ(empty4.capacity(), v.capacity());  // Must still be inline
-  EXPECT_EQ(4, v[0]);
-  EXPECT_EQ(5, v[1]);
-  EXPECT_EQ(6, v[2]);
-  EXPECT_EQ(7, v[3]);
-
-  // Now, force a re-allocate
-  tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
-                                                   source_v.end());
-  tensorflow::gtl::InlinedVector<int, 2> empty2;
-  EXPECT_EQ(4, realloc_v.size());
-  EXPECT_LT(empty2.capacity(), realloc_v.capacity());
-  EXPECT_EQ(4, realloc_v[0]);
-  EXPECT_EQ(5, realloc_v[1]);
-  EXPECT_EQ(6, realloc_v[2]);
-  EXPECT_EQ(7, realloc_v[3]);
-}
-
-TEST(RangedConstructor, ComplexType) {
-  // We also use a list here to pass a different flavor of iterator (e.g. not
-  // random-access).
-  std::list<Instance> source_v = {Instance(0)};
-
-  // First try to fit in inline backing
-  tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
-                                                source_v.end());
-  tensorflow::gtl::InlinedVector<Instance, 1> empty1;
-  EXPECT_EQ(1, v.size());
-  EXPECT_EQ(empty1.capacity(), v.capacity());  // Must still be inline
-  EXPECT_EQ(0, v[0].value_);
-
-  std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2),
-                                   Instance(3)};
-  // Now, force a re-allocate
-  tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
-                                                        source_v2.end());
-  EXPECT_EQ(4, realloc_v.size());
-  EXPECT_LT(empty1.capacity(), realloc_v.capacity());
-  EXPECT_EQ(0, realloc_v[0].value_);
-  EXPECT_EQ(1, realloc_v[1].value_);
-  EXPECT_EQ(2, realloc_v[2].value_);
-  EXPECT_EQ(3, realloc_v[3].value_);
-}
-
-TEST(RangedConstructor, ElementsAreConstructed) {
-  std::vector<string> source_v = {"cat", "dog"};
-
-  // Force expansion and re-allocation of v.  Ensures that when the vector is
-  // expanded that new elements are constructed.
-  tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
-  EXPECT_EQ("cat", v[0]);
-  EXPECT_EQ("dog", v[1]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
-  auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6};
-  EXPECT_EQ(3, vec.size());
-  EXPECT_EQ(3, vec.capacity());
-  EXPECT_EQ(4, vec[0]);
-  EXPECT_EQ(5, vec[1]);
-  EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
-  auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
-  EXPECT_EQ(3, vec.size());
-  EXPECT_LE(3, vec.capacity());
-  EXPECT_EQ(4, vec[0]);
-  EXPECT_EQ(5, vec[1]);
-  EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, DisparateTypesInList) {
-  EXPECT_EQ((std::vector<int>{-7, 8}),
-            Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
-
-  EXPECT_EQ(
-      (std::vector<string>{"foo", "bar"}),
-      Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
-  tensorflow::gtl::InlinedVector<Instance, 1> empty;
-  auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
-  EXPECT_EQ(1, vec.size());
-  EXPECT_EQ(empty.capacity(), vec.capacity());
-  EXPECT_EQ(0, vec[0].value_);
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
-  auto vec =
-      tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
-  EXPECT_EQ(2, vec.size());
-  EXPECT_LE(2, vec.capacity());
-  EXPECT_EQ(0, vec[0].value_);
-  EXPECT_EQ(1, vec[1].value_);
-}
-
-TEST(DynamicVec, DynamicVecCompiles) {
-  DynamicVec v;
-  (void)v;
-}
-
-static void BM_InlinedVectorFill(int iters, int len) {
-  for (int i = 0; i < iters; i++) {
-    IntVec v;
-    for (int j = 0; j < len; j++) {
-      v.push_back(j);
-    }
-  }
-  testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
-
-static void BM_InlinedVectorFillRange(int iters, int len) {
-  std::unique_ptr<int[]> ia(new int[len]);
-  for (int j = 0; j < len; j++) {
-    ia[j] = j;
-  }
-  for (int i = 0; i < iters; i++) {
-    IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
-  }
-  testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
-
-static void BM_StdVectorFill(int iters, int len) {
-  for (int i = 0; i < iters; i++) {
-    std::vector<int> v;
-    v.reserve(len);
-    for (int j = 0; j < len; j++) {
-      v.push_back(j);
-    }
-  }
-  testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
-
-bool StringRepresentedInline(string s) {
-  const char* chars = s.data();
-  string s1 = std::move(s);
-  return s1.data() != chars;
-}
-
-static void BM_InlinedVectorFillString(int iters, int len) {
-  string strings[4] = {"a quite long string", "another long string",
-                       "012345678901234567", "to cause allocation"};
-  for (int i = 0; i < iters; i++) {
-    gtl::InlinedVector<string, 8> v;
-    for (int j = 0; j < len; j++) {
-      v.push_back(strings[j & 3]);
-    }
-  }
-  testing::ItemsProcessed(int64{iters} * len);
-}
-BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024);
-
-static void BM_StdVectorFillString(int iters, int len) {
-  string strings[4] = {"a quite long string", "another long string",
-                       "012345678901234567", "to cause allocation"};
-  for (int i = 0; i < iters; i++) {
-    std::vector<string> v;
-    v.reserve(len);
-    for (int j = 0; j < len; j++) {
-      v.push_back(strings[j & 3]);
-    }
-  }
-  testing::ItemsProcessed(int64{iters} * len);
-  // The purpose of the benchmark is to verify that inlined vector is
-  // efficient when moving is more efficient than copying. To do so, we
-  // use strings that are larger than the small string optimization.
-  CHECK(!StringRepresentedInline(strings[0]));
-}
-BENCHMARK(BM_StdVectorFillString)->Range(0, 1024);
-
-namespace {
-struct Buffer {  // some arbitrary structure for benchmarking.
-  char* base;
-  int length;
-  int capacity;
-  void* user_data;
-};
-}  // anonymous namespace
-
-static void BM_InlinedVectorTenAssignments(int iters, int len) {
-  typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
-
-  BufferVec src;
-  src.resize(len);
-
-  iters *= 10;
-  BufferVec dst;
-  for (int i = 0; i < iters; i++) {
-    dst = src;
-  }
-}
-BENCHMARK(BM_InlinedVectorTenAssignments)
-    ->Arg(0)
-    ->Arg(1)
-    ->Arg(2)
-    ->Arg(3)
-    ->Arg(4)
-    ->Arg(20);
-
-static void BM_CreateFromInitializerList(int iters) {
-  for (; iters > 0; iters--) {
-    tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
-    (void)x[0];
-  }
-}
-BENCHMARK(BM_CreateFromInitializerList);
-
-namespace {
-
-struct LargeSwappable {
-  LargeSwappable() : d_(1024, 17) {}
-  ~LargeSwappable() {}
-  LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
-
-  friend void swap(LargeSwappable& a, LargeSwappable& b) {
-    using std::swap;
-    swap(a.d_, b.d_);
-  }
-
-  LargeSwappable& operator=(LargeSwappable o) {
-    using std::swap;
-    swap(*this, o);
-    return *this;
-  }
-
-  std::vector<int> d_;
-};
-
-}  // namespace
-
-static void BM_LargeSwappableElements(int iters, int len) {
-  typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
-  Vec a(len);
-  Vec b;
-  while (--iters >= 0) {
-    using std::swap;
-    swap(a, b);
-  }
-}
-BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index c24628b..f93ebea 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -109,9 +109,6 @@
 }
 
 Status RecordReader::ReadRecord(uint64* offset, string* record) {
-  static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
-  static const size_t kFooterSize = sizeof(uint32);
-
   // Position the input stream.
   int64 curr_pos = input_stream_->Tell();
   int64 desired_pos = static_cast<int64>(*offset);
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index c05f9e1..11af136 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -58,6 +58,14 @@
 // Note: this class is not thread safe; external synchronization required.
 class RecordReader {
  public:
+  // Format of a single record:
+  //  uint64    length
+  //  uint32    masked crc of length
+  //  byte      data[length]
+  //  uint32    masked crc of data
+  static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+  static const size_t kFooterSize = sizeof(uint32);
+
   // Create a reader that will return log records from "*file".
   // "*file" must remain live while this Reader is in use.
   explicit RecordReader(
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 6e71d23..2c6db24 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -88,10 +88,6 @@
   }
 }
 
-static uint32 MaskedCrc(const char* data, size_t n) {
-  return crc32c::Mask(crc32c::Value(data, n));
-}
-
 Status RecordWriter::WriteRecord(StringPiece data) {
   if (dest_ == nullptr) {
     return Status(::tensorflow::error::FAILED_PRECONDITION,
@@ -102,13 +98,10 @@
   //  uint32    masked crc of length
   //  byte      data[length]
   //  uint32    masked crc of data
-  char header[sizeof(uint64) + sizeof(uint32)];
-  core::EncodeFixed64(header + 0, data.size());
-  core::EncodeFixed32(header + sizeof(uint64),
-                      MaskedCrc(header, sizeof(uint64)));
-  char footer[sizeof(uint32)];
-  core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
-
+  char header[kHeaderSize];
+  char footer[kFooterSize];
+  PopulateHeader(header, data.data(), data.size());
+  PopulateFooter(footer, data.data(), data.size());
   TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
   TF_RETURN_IF_ERROR(dest_->Append(data));
   return dest_->Append(StringPiece(footer, sizeof(footer)));
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 2f6afa5..1212e1f 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -16,8 +16,10 @@
 #ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
 #define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
 
+#include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
 #if !defined(IS_SLIM_BUILD)
 #include "tensorflow/core/lib/io/zlib_compression_options.h"
 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@@ -41,12 +43,20 @@
 
 // Options specific to zlib compression.
 #if !defined(IS_SLIM_BUILD)
-  ZlibCompressionOptions zlib_options;
+  tensorflow::io::ZlibCompressionOptions zlib_options;
 #endif  // IS_SLIM_BUILD
 };
 
 class RecordWriter {
  public:
+  // Format of a single record:
+  //  uint64    length
+  //  uint32    masked crc of length
+  //  byte      data[length]
+  //  uint32    masked crc of data
+  static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+  static const size_t kFooterSize = sizeof(uint32);
+
   // Create a writer that will append data to "*dest".
   // "*dest" must be initially empty.
   // "*dest" must remain live while this Writer is in use.
@@ -72,13 +82,35 @@
   // are invalid.
   Status Close();
 
+  // Utility method to populate TFRecord headers.  Populates record-header in
+  // "header[0,kHeaderSize-1]".  The record-header is based on data[0, n-1].
+  inline static void PopulateHeader(char* header, const char* data, size_t n);
+
+  // Utility method to populate TFRecord footers.  Populates record-footer in
+  // "footer[0,kFooterSize-1]".  The record-footer is based on data[0, n-1].
+  inline static void PopulateFooter(char* footer, const char* data, size_t n);
+
  private:
   WritableFile* dest_;
   RecordWriterOptions options_;
 
+  inline static uint32 MaskedCrc(const char* data, size_t n) {
+    return crc32c::Mask(crc32c::Value(data, n));
+  }
+
   TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
 };
 
+void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) {
+  core::EncodeFixed64(header + 0, n);
+  core::EncodeFixed32(header + sizeof(uint64),
+                      MaskedCrc(header, sizeof(uint64)));
+}
+
+void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) {
+  core::EncodeFixed32(footer, MaskedCrc(data, n));
+}
+
 }  // namespace io
 }  // namespace tensorflow
 
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index da514bd..946d718 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -58,7 +58,7 @@
   Status Close() override { return Status::OK(); }
   Status Flush() override { return Status::OK(); }
   Status Sync() override { return Status::OK(); }
-  Status Append(const StringPiece& slice) override {
+  Status Append(StringPiece slice) override {
     contents_->append(slice.data(), slice.size());
     return Status::OK();
   }
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 877ac40..9cebbf4 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -98,7 +98,7 @@
   Status Flush() override { return Status::OK(); }
   Status Sync() override { return Status::OK(); }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     contents_.append(data.data(), data.size());
     return Status::OK();
   }
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 84b47c1..cba139e 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -143,7 +143,7 @@
   return Status::OK();
 }
 
-Status ZlibOutputBuffer::Append(const StringPiece& data) {
+Status ZlibOutputBuffer::Append(StringPiece data) {
   // If there is sufficient free space in z_stream_input_ to fit data we
   // add it there and return.
   // If there isn't enough space we deflate the existing contents of
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 3d86d89..ccad2fd 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -62,7 +62,7 @@
   // to file when the buffer is full.
   //
   // To immediately write contents to file call `Flush()`.
-  Status Append(const StringPiece& data) override;
+  Status Append(StringPiece data) override;
 
   // Deflates any cached input and writes all output to file.
   Status Flush() override;
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 36d939e..c536b56 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -232,6 +232,11 @@
         "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
   }
   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+  if (*channel_count < 1) {
+    return errors::InvalidArgument(
+        "Bad number of channels for WAV: Expected at least 1, but got ",
+        *channel_count);
+  }
   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
   uint32 bytes_per_second;
   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3..7c4184b 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@
 
 namespace tensorflow {
 
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
 REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
 
 REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -354,4 +358,125 @@
       return Status::OK();
     });
 
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+    .Input("quantile_stream_resource_handle: resource")
+    .Output("is_initialized: bool")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+    .Attr("max_elements: int = 1099511627776")  // 1 << 40
+    .Input("quantile_stream_resource_handle: resource")
+    .Input("epsilon: float")
+    .Input("num_streams: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+    .Attr("num_features: int >= 0")
+    .Input("float_values: num_features * float")
+    .Input("example_weights: float")
+    .Input("epsilon: float")
+    .Output("summaries: num_features * float")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      ShapeHandle example_weights_shape;
+      TF_RETURN_IF_ERROR(
+          c->WithRank(c->input(num_features), 1, &example_weights_shape));
+      for (int i = 0; i < num_features; ++i) {
+        ShapeHandle feature_shape;
+        DimensionHandle unused_dim;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+                                    c->Dim(example_weights_shape, 0),
+                                    &unused_dim));
+        // the columns are value, weight, min_rank, max_rank.
+        c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+      }
+      // epsilon must be a scalar.
+      ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(
+          c->WithRank(c->input(num_features + 1), 0, &unused_input));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+    .Attr("num_features: int >= 0")
+    .Input("quantile_stream_resource_handle: resource")
+    .Input("summaries: num_features * float")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      // resource handle must be a scalar.
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // each summary must be rank 2.
+      for (int i = 1; i < num_features + 1; i++) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+      }
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+    .Attr("generate_quantiles: bool = False")
+    .Input("quantile_stream_resource_handle: resource")
+    .Input("num_buckets: int64")
+    .SetShapeFn([](InferenceContext* c) {
+      // All the inputs are scalars.
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+    .Attr("num_features: int >= 0")
+    .Input("quantile_stream_resource_handle: resource")
+    .Output("bucket_boundaries: num_features * float")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      shape_inference::ShapeHandle unused_input;
+      // resource handle must be a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      for (int i = 0; i < num_features; i++) {
+        c->set_output(i, c->Vector(c->UnknownDim()));
+      }
+      return Status::OK();
+    });
+
+REGISTER_OP("BoostedTreesBucketize")
+    .Attr("num_features: int >= 0")
+    .Input("float_values: num_features * float")
+    .Input("bucket_boundaries: num_features * float")
+    .Output("buckets: num_features * int32")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_features;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+      ShapeHandle feature_shape;
+      DimensionHandle unused_dim;
+      for (int i = 0; i < num_features; i++) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+                                    c->Dim(c->input(0), 0), &unused_dim));
+      }
+      // Bucketized result should have same dimension as input.
+      for (int i = 0; i < num_features; i++) {
+        c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+      }
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index cb0cb46..57c6bda 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11360,6 +11360,29 @@
   is_commutative: true
 }
 op {
+  name: "BoostedTreesBucketize"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "buckets"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
   name: "BoostedTreesCalculateBestGainsPerFeature"
   input_arg {
     name: "node_id_range"
@@ -11469,6 +11492,29 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesCreateQuantileStreamResource"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_streams"
+    type: DT_INT64
+  }
+  attr {
+    name: "max_elements"
+    type: "int"
+    default_value {
+      i: 1099511627776
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "BoostedTreesDeserializeEnsemble"
   input_arg {
     name: "tree_ensemble_handle"
@@ -11562,6 +11608,32 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesMakeQuantileSummaries"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
   name: "BoostedTreesMakeStatsSummary"
   input_arg {
     name: "node_ids"
@@ -11631,6 +11703,83 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesQuantileStreamResourceAddSummaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceFlush"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "num_buckets"
+    type: DT_INT64
+  }
+  attr {
+    name: "generate_quantiles"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "BoostedTreesSerializeEnsemble"
   input_arg {
     name: "tree_ensemble_handle"
@@ -13070,6 +13219,71 @@
   is_stateful: true
 }
 op {
+  name: "ConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "Conj"
   input_arg {
     name: "input"
@@ -27127,6 +27341,18 @@
   is_stateful: true
 }
 op {
+  name: "IsBoostedTreesQuantileStreamResourceInitialized"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
+op {
   name: "IsFinite"
   input_arg {
     name: "x"
@@ -29381,6 +29607,49 @@
   }
 }
 op {
+  name: "MapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
   name: "MapDefun"
   input_arg {
     name: "arguments"
@@ -34842,6 +35111,29 @@
   }
 }
 op {
+  name: "ModelDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
   name: "Mul"
   input_arg {
     name: "x"
@@ -35682,6 +35974,42 @@
   }
 }
 op {
+  name: "NonMaxSuppressionV2"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
   name: "NonMaxSuppressionV3"
   input_arg {
     name: "boxes"
@@ -35709,6 +36037,46 @@
   }
 }
 op {
+  name: "NonMaxSuppressionV3"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+}
+op {
   name: "NonMaxSuppressionV4"
   input_arg {
     name: "boxes"
@@ -35747,6 +36115,57 @@
   }
 }
 op {
+  name: "NonMaxSuppressionV4"
+  input_arg {
+    name: "boxes"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "scores"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "max_output_size"
+    type: DT_INT32
+  }
+  input_arg {
+    name: "iou_threshold"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "score_threshold"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "selected_indices"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "valid_outputs"
+    type: DT_INT32
+  }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
+    name: "pad_to_max_output_size"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+}
+op {
   name: "NonMaxSuppressionWithOverlaps"
   input_arg {
     name: "overlaps"
@@ -37037,6 +37456,54 @@
   }
 }
 op {
+  name: "ParallelInterleaveDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
   name: "ParallelMapDataset"
   input_arg {
     name: "input_dataset"
@@ -37118,6 +37585,53 @@
   }
 }
 op {
+  name: "ParallelMapDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+}
+op {
   name: "ParameterizedTruncatedNormal"
   input_arg {
     name: "shape"
@@ -64500,6 +65014,71 @@
   is_stateful: true
 }
 op {
+  name: "SparseConditionalAccumulator"
+  output_arg {
+    name: "handle"
+    type: DT_STRING
+    is_ref: true
+  }
+  attr {
+    name: "dtype"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_COMPLEX64
+        type: DT_INT64
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_BFLOAT16
+        type: DT_UINT16
+        type: DT_COMPLEX128
+        type: DT_HALF
+        type: DT_UINT32
+        type: DT_UINT64
+      }
+    }
+  }
+  attr {
+    name: "shape"
+    type: "shape"
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "SparseCross"
   input_arg {
     name: "indices"
@@ -69293,6 +69872,21 @@
   }
 }
 op {
+  name: "StaticRegexFullMatch"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+  attr {
+    name: "pattern"
+    type: "string"
+  }
+}
+op {
   name: "StaticRegexReplace"
   input_arg {
     name: "input"
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce..ffab8ad 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@
     .Attr("shape: shape")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
+    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@
     .Attr("shape: shape")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
+    .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
       c->set_output(0, c->Vector(2));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index f03639e..7d9e7b2 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -198,6 +198,7 @@
     .Attr("Targuments: list(type) >= 0")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
+    .Attr("use_inter_op_parallelism: bool = true")
     .SetShapeFn(shape_inference::ScalarShape);
 
 REGISTER_OP("ParallelMapDataset")
@@ -209,6 +210,7 @@
     .Attr("Targuments: list(type) >= 0")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
+    .Attr("use_inter_op_parallelism: bool = true")
     .SetShapeFn(shape_inference::ScalarShape);
 
 REGISTER_OP("MapAndBatchDataset")
@@ -325,6 +327,19 @@
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("ParallelInterleaveDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("cycle_length: int64")
+    .Input("block_length: int64")
+    .Input("num_parallel_calls: int64")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("GroupByReducerDataset")
     .Input("input_dataset: variant")
     .Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -858,6 +873,13 @@
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("ModelDataset")
+    .Input("input_dataset: variant")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("MapDefun")
     .Input("arguments: Targuments")
     .Output("output: output_types")
@@ -866,7 +888,7 @@
     .Attr("output_shapes: list(shape) >= 1")
     .Attr("f: func")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
-      std::vector<TensorShape> output_shapes;
+      std::vector<PartialTensorShape> output_shapes;
       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
       if (output_shapes.size() != c->num_outputs()) {
         return errors::InvalidArgument(
@@ -876,6 +898,10 @@
 
       int64 dim_zero = -1;
       for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+        if (c->Rank(c->input(i)) == 0) {
+          return errors::InvalidArgument(
+              "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+        }
         auto dim_handle = c->Dim(c->input(i), 0);
         if (c->ValueKnown(dim_handle)) {
           if (dim_zero == -1) {
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 11ca0bd..5427275 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -683,11 +683,12 @@
     });
 
 REGISTER_OP("NonMaxSuppressionV2")
-    .Input("boxes: float")
-    .Input("scores: float")
+    .Input("boxes: T")
+    .Input("scores: T")
     .Input("max_output_size: int32")
     .Input("iou_threshold: float")
     .Output("selected_indices: int32")
+    .Attr("T: {half, float} = DT_FLOAT")
     .SetShapeFn([](InferenceContext* c) {
       // Get inputs and validate ranks.
       ShapeHandle boxes;
@@ -711,22 +712,24 @@
     });
 
 REGISTER_OP("NonMaxSuppressionV3")
-    .Input("boxes: float")
-    .Input("scores: float")
+    .Input("boxes: T")
+    .Input("scores: T")
     .Input("max_output_size: int32")
     .Input("iou_threshold: float")
     .Input("score_threshold: float")
     .Output("selected_indices: int32")
+    .Attr("T: {half, float} = DT_FLOAT")
     .SetShapeFn(NMSShapeFn);
 
 REGISTER_OP("NonMaxSuppressionV4")
-    .Input("boxes: float")
-    .Input("scores: float")
+    .Input("boxes: T")
+    .Input("scores: T")
     .Input("max_output_size: int32")
     .Input("iou_threshold: float")
     .Input("score_threshold: float")
     .Output("selected_indices: int32")
     .Output("valid_outputs: int32")
+    .Attr("T: {half, float} = DT_FLOAT")
     .Attr("pad_to_max_output_size: bool = false")
     .SetShapeFn([](InferenceContext* c) {
       TF_RETURN_IF_ERROR(NMSShapeFn(c));
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4419f93..190f6aa 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4272,6 +4272,29 @@
   is_commutative: true
 }
 op {
+  name: "BoostedTreesBucketize"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  output_arg {
+    name: "buckets"
+    type: DT_INT32
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
   name: "BoostedTreesCalculateBestGainsPerFeature"
   input_arg {
     name: "node_id_range"
@@ -4381,6 +4404,29 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesCreateQuantileStreamResource"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "num_streams"
+    type: DT_INT64
+  }
+  attr {
+    name: "max_elements"
+    type: "int"
+    default_value {
+      i: 1099511627776
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "BoostedTreesDeserializeEnsemble"
   input_arg {
     name: "tree_ensemble_handle"
@@ -4474,6 +4520,32 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesMakeQuantileSummaries"
+  input_arg {
+    name: "float_values"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  input_arg {
+    name: "example_weights"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "epsilon"
+    type: DT_FLOAT
+  }
+  output_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+}
+op {
   name: "BoostedTreesMakeStatsSummary"
   input_arg {
     name: "node_ids"
@@ -4543,6 +4615,83 @@
   is_stateful: true
 }
 op {
+  name: "BoostedTreesQuantileStreamResourceAddSummaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "summaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceFlush"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "num_buckets"
+    type: DT_INT64
+  }
+  attr {
+    name: "generate_quantiles"
+    type: "bool"
+    default_value {
+      b: false
+    }
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "bucket_boundaries"
+    type: DT_FLOAT
+    number_attr: "num_features"
+  }
+  attr {
+    name: "num_features"
+    type: "int"
+    has_minimum: true
+  }
+  is_stateful: true
+}
+op {
+  name: "BoostedTreesQuantileStreamResourceHandleOp"
+  output_arg {
+    name: "resource"
+    type: DT_RESOURCE
+  }
+  attr {
+    name: "container"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "shared_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  is_stateful: true
+}
+op {
   name: "BoostedTreesSerializeEnsemble"
   input_arg {
     name: "tree_ensemble_handle"
@@ -5592,6 +5741,19 @@
       s: ""
     }
   }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
   is_stateful: true
 }
 op {
@@ -13149,6 +13311,18 @@
   is_stateful: true
 }
 op {
+  name: "IsBoostedTreesQuantileStreamResourceInitialized"
+  input_arg {
+    name: "quantile_stream_resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "is_initialized"
+    type: DT_BOOL
+  }
+  is_stateful: true
+}
+op {
   name: "IsFinite"
   input_arg {
     name: "x"
@@ -14542,6 +14716,13 @@
     has_minimum: true
     minimum: 1
   }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
 }
 op {
   name: "MapDefun"
@@ -16540,6 +16721,29 @@
   }
 }
 op {
+  name: "ModelDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
   name: "Mul"
   input_arg {
     name: "x"
@@ -17078,11 +17282,11 @@
   name: "NonMaxSuppressionV2"
   input_arg {
     name: "boxes"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "scores"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "max_output_size"
@@ -17096,16 +17300,29 @@
     name: "selected_indices"
     type: DT_INT32
   }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
 }
 op {
   name: "NonMaxSuppressionV3"
   input_arg {
     name: "boxes"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "scores"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "max_output_size"
@@ -17123,16 +17340,29 @@
     name: "selected_indices"
     type: DT_INT32
   }
+  attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
 }
 op {
   name: "NonMaxSuppressionV4"
   input_arg {
     name: "boxes"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "scores"
-    type: DT_FLOAT
+    type_attr: "T"
   }
   input_arg {
     name: "max_output_size"
@@ -17155,6 +17385,19 @@
     type: DT_INT32
   }
   attr {
+    name: "T"
+    type: "type"
+    default_value {
+      type: DT_FLOAT
+    }
+    allowed_values {
+      list {
+        type: DT_HALF
+        type: DT_FLOAT
+      }
+    }
+  }
+  attr {
     name: "pad_to_max_output_size"
     type: "bool"
     default_value {
@@ -18192,6 +18435,54 @@
   }
 }
 op {
+  name: "ParallelInterleaveDatasetV2"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_calls"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
+op {
   name: "ParallelMapDataset"
   input_arg {
     name: "input_dataset"
@@ -18230,6 +18521,13 @@
     has_minimum: true
     minimum: 1
   }
+  attr {
+    name: "use_inter_op_parallelism"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
 }
 op {
   name: "ParameterizedTruncatedNormal"
@@ -29610,6 +29908,19 @@
       s: ""
     }
   }
+  attr {
+    name: "reduction_type"
+    type: "string"
+    default_value {
+      s: "MEAN"
+    }
+    allowed_values {
+      list {
+        s: "MEAN"
+        s: "SUM"
+      }
+    }
+  }
   is_stateful: true
 }
 op {
@@ -32108,6 +32419,21 @@
   }
 }
 op {
+  name: "StaticRegexFullMatch"
+  input_arg {
+    name: "input"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "output"
+    type: DT_BOOL
+  }
+  attr {
+    name: "pattern"
+    type: "string"
+  }
+}
+op {
   name: "StaticRegexReplace"
   input_arg {
     name: "input"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 79ca96d..eff4532 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -343,10 +343,11 @@
       // Validate the record_defaults inputs.
       for (int i = 1; i < c->num_inputs(); ++i) {
         ShapeHandle v;
-        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
-        if (c->Value(c->Dim(v, 0)) > 1) {
+        TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+        if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
           return errors::InvalidArgument(
-              "Shape of a default must be a length-0 or length-1 vector");
+              "Shape of a default must be a length-0 or length-1 vector, or a "
+              "scalar.");
         }
       }
 
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index c65e66d..ba594e4 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -52,9 +52,12 @@
   INFER_OK(op, "[1,2,?,4];?;?", "in0;in0");
   INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0");
 
+  // Scalar defaults are ok
+  INFER_OK(op, "?;?;[]", "in0;in0");
+
   // Check errors in the record_defaults inputs.
-  INFER_ERROR("must be rank 1", op, "?;?;[]");
-  INFER_ERROR("must be rank 1", op, "?;[];?");
+  INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]");
+  INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?");
   INFER_ERROR("Shape of a default must be", op, "?;?;[2]");
   INFER_ERROR("Shape of a default must be", op, "?;[2];?");
 }
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 7aa1e71..ef8b15d 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -56,6 +56,12 @@
       return Status::OK();
     });
 
+REGISTER_OP("StaticRegexFullMatch")
+    .Input("input: string")
+    .Attr("pattern: string")
+    .Output("output: bool")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
 REGISTER_OP("StringToHashBucketFast")
     .Input("input: string")
     .Output("output: int64")
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index e597a49..d7a13a3 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -37,13 +37,13 @@
 namespace tensorflow {
 namespace port {
 
-std::string MaybeAbiDemangle(const char* name) {
+string MaybeAbiDemangle(const char* name) {
 #if defined(_MSC_VER)
   std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
                                             std::free,
                                             static_cast<unsigned short>(0))};
 
-  return std::string(demangled.get() != nullptr ? demangled.get() : name);
+  return string(demangled.get() != nullptr ? demangled.get() : name);
 #else
   int status = 0;
   std::unique_ptr<char, void (*)(void*)> res{
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 591e83b..d1498a6 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -17,11 +17,12 @@
 #define TENSORFLOW_CORE_PLATFORM_ABI_H_
 
 #include <string>
+#include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace port {
 
-std::string MaybeAbiDemangle(const char* name);
+string MaybeAbiDemangle(const char* name);
 
 }  // namespace port
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index a1be4aa..5e1eabe 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -394,9 +394,9 @@
           .StopCapture()
           .OneLiteral(": ")
           .GetResult(&value, &name)) {
-    string str_value = std::string(value);
+    string str_value(value);
     str_util::StripTrailingWhitespace(&str_value);
-    that->response_headers_[std::string(name)] = str_value;
+    that->response_headers_[string(name)] = str_value;
   }
   return size * nmemb;
 }
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 9d33787..83228fa 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -179,13 +179,13 @@
     return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
                                    fname);
   }
-  *bucket = std::string(bucketp);
+  *bucket = string(bucketp);
   if (bucket->empty() || *bucket == ".") {
     return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
                                    fname);
   }
   str_util::ConsumePrefix(&objectp, "/");
-  *object = std::string(objectp);
+  *object = string(objectp);
   if (!empty_object_ok && object->empty()) {
     return errors::InvalidArgument("GCS path doesn't contain an object name: ",
                                    fname);
@@ -224,7 +224,7 @@
   for (const string& path : paths) {
     StringPiece subpath = io::Dirname(path);
     while (!subpath.empty()) {
-      result.emplace(std::string(subpath));
+      result.emplace(string(subpath));
       subpath = io::Dirname(subpath);
     }
   }
@@ -371,7 +371,7 @@
 
   ~GcsWritableFile() override { Close().IgnoreError(); }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     TF_RETURN_IF_ERROR(CheckWritable());
     sync_needed_ = true;
     outfile_ << data;
@@ -723,7 +723,7 @@
 
       if (!header_name.empty() && !header_value.empty()) {
         additional_header_.reset(new std::pair<const string, const string>(
-            std::string(header_name), std::string(header_value)));
+            string(header_name), string(header_value)));
 
         VLOG(1) << "GCS additional header ENABLED. "
                 << "Name: " << additional_header_->first << ", "
@@ -1229,7 +1229,7 @@
         // Find the fixed prefix by looking for the first wildcard.
         const string& fixed_prefix =
             pattern.substr(0, pattern.find_first_of("*?[\\"));
-        const string& dir = std::string(io::Dirname(fixed_prefix));
+        const string dir(io::Dirname(fixed_prefix));
         if (dir.empty()) {
           return errors::InvalidArgument(
               "A GCS pattern doesn't have a bucket name: ", pattern);
@@ -1326,7 +1326,7 @@
               " doesn't match the prefix ", object_prefix));
         }
         if (!relative_path.empty() || include_self_directory_marker) {
-          result->emplace_back(std::string(relative_path));
+          result->emplace_back(relative_path);
         }
         if (++retrieved_results >= max_results) {
           return Status::OK();
@@ -1354,7 +1354,7 @@
               "Unexpected response: the returned folder name ", prefix_str,
               " doesn't match the prefix ", object_prefix);
         }
-        result->emplace_back(std::string(relative_path));
+        result->emplace_back(relative_path);
         if (++retrieved_results >= max_results) {
           return Status::OK();
         }
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index ee6ba7b..9b85cae 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -216,7 +216,7 @@
   // Send the request to the Google OAuth 2.0 server to get the token.
   std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
   std::vector<char> response_buffer;
-  request->SetUri(std::string(oauth_server_uri));
+  request->SetUri(string(oauth_server_uri));
   request->SetPostFromBuffer(request_body.c_str(), request_body.size());
   request->SetResultBuffer(&response_buffer);
   TF_RETURN_IF_ERROR(request->Send());
@@ -248,7 +248,7 @@
 
   std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
   std::vector<char> response_buffer;
-  request->SetUri(std::string(oauth_server_uri));
+  request->SetUri(string(oauth_server_uri));
   request->SetPostFromBuffer(request_body.c_str(), request_body.size());
   request->SetResultBuffer(&response_buffer);
   TF_RETURN_IF_ERROR(request->Send());
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 4ffa722..1cd0641 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -126,9 +126,9 @@
   EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
             grant_type);
 
-  int last_dot = std::string(assertion).find_last_of(".");
-  string header_dot_claim = std::string(assertion.substr(0, last_dot));
-  string signature_encoded = std::string(assertion.substr(last_dot + 1));
+  int last_dot = assertion.rfind('.');
+  string header_dot_claim(assertion.substr(0, last_dot));
+  string signature_encoded(assertion.substr(last_dot + 1));
 
   // Check that 'signature' signs 'header_dot_claim'.
 
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 92aa72b..941ab7a 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -177,7 +177,7 @@
     Close().IgnoreError();
   }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     return RetryingUtils::CallWithRetries(
         [this, &data]() { return base_file_->Append(data); },
         initial_delay_microseconds_);
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index ec2c470..5910fef 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -72,7 +72,7 @@
 class MockWritableFile : public WritableFile {
  public:
   explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {}
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     return calls_.ConsumeNextCall("Append");
   }
   Status Close() override { return calls_.ConsumeNextCall("Close"); }
diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h
new file mode 100644
index 0000000..7c5c665
--- /dev/null
+++ b/tensorflow/core/platform/cord.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_CORD_H_
+
+// Include appropriate platform-dependent implementations
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/cord.h"
+#else
+#include "tensorflow/core/platform/default/cord.h"
+#endif
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CORD_H_
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 07b2e34..bb841ae 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -625,6 +625,7 @@
     """Additional dependencies needed to build TF libraries."""
     return [
         "@com_google_absl//absl/base:base",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/types:span",
         "@com_google_absl//absl/types:optional",
     ] + if_static(
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
new file mode 100644
index 0000000..1ab6821
--- /dev/null
+++ b/tensorflow/core/platform/default/cord.h
@@ -0,0 +1,24 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+
+class Cord;
+namespace absl {
+using ::Cord;
+}  // namespace absl
+
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index ccddf1e..0389149 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,6 +321,11 @@
     return nullptr;
   }
 
+  bool IsEnabled(bool is_expensive) const override {
+    // We don't do anything with 'Activities' so we are never 'enabled'.
+    return false;
+  }
+
  protected:
   // This callback is used exclusively by CUPTIManager.
   friend class CUPTIManager;
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 305a9a6..2e32abd 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/null_file_system.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
@@ -345,7 +346,13 @@
   // Write something to the temporary file.
   std::unique_ptr<WritableFile> file_to_write;
   TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
+#if defined(PLATFORM_GOOGLE)
+  TF_CHECK_OK(file_to_write->Append("Nu"));
+  TF_CHECK_OK(file_to_write->Append(absl::Cord("ll")));
+#else
+  // TODO(ebrevdo): Remove this version.
   TF_CHECK_OK(file_to_write->Append("Null"));
+#endif
   TF_CHECK_OK(file_to_write->Close());
   TF_CHECK_OK(env->FileExists(filename));
 
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 077b1d7..30059dc 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -24,6 +24,7 @@
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/cord.h"
 #include "tensorflow/core/platform/file_statistics.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/platform.h"
@@ -252,7 +253,12 @@
   virtual ~WritableFile();
 
   /// \brief Append 'data' to the file.
-  virtual Status Append(const StringPiece& data) = 0;
+  virtual Status Append(StringPiece data) = 0;
+
+  // \brief Append 'data' to the file.
+  virtual Status Append(const absl::Cord& cord) {
+    return errors::Unimplemented("Append(absl::Cord) is not implemented");
+  }
 
   /// \brief Close the file.
   ///
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 8cdb08f..eb35531 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -282,7 +282,7 @@
     }
   }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     if (hdfs_->hdfsWrite(fs_, file_, data.data(),
                          static_cast<tSize>(data.size())) == -1) {
       return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 47bfa02..c7afab9 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -91,7 +91,7 @@
     }
   }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     size_t r = fwrite(data.data(), 1, data.size(), file_);
     if (r != data.size()) {
       return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index ce0f6cd..e0b8e37 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -211,7 +211,7 @@
             std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
                 std::ios_base::out)) {}
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     if (!outfile_) {
       return errors::FailedPrecondition(
           "The internal temporary file is not writable.");
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index e5851f1..9974bbb 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -155,6 +155,10 @@
       StringPiece name_part1, StringPiece name_part2,
       bool is_expensive) const = 0;
 
+  // Returns true if this activity handle tracking is enabled for an op of the
+  // given expensiveness.
+  virtual bool IsEnabled(bool is_expensive) const = 0;
+
  protected:
   static string ConcatenateNames(StringPiece first, StringPiece second);
 
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 9079a5c..6cf7963 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -150,7 +150,7 @@
     }
   }
 
-  Status Append(const StringPiece& data) override {
+  Status Append(StringPiece data) override {
     DWORD bytes_written = 0;
     DWORD data_size = static_cast<DWORD>(data.size());
     BOOL write_result =
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index da3a995..625d564 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -390,9 +390,12 @@
   message Experimental {
     // Task name for group resolution.
     string collective_group_leader = 1;
-    // Whether the client will format templated errors. For example, the string:
-    // "The node was defined on ^^node:Foo:${file}:${line}^^".
-    bool client_handles_error_formatting = 2;
+
+    // We removed the flag client_handles_error_formatting. Marking the tag
+    // number as reserved.
+    // TODO(shikharagarwal): Should we just remove this tag so that it can be
+    // used in future for other purpose?
+    reserved 2;
 
     // Which executor to use, the default executor will be used
     // if it is an empty string or "DEFAULT"
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index 973e315..24002e7 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// LINT.IfChange
 
 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index 1a622ba..1e45a8a 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// LINT.IfChange
 
 // Collection of scoring classes that can be extended and provided to the
 // CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index 5e2aeb7..6fbb1ed 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// LINT.IfChange
 
 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
 #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 3be3682..b55d7d7 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// LINT.IfChange
 
 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
 #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index 36be9e92..054412d 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+// LINT.IfChange
 
 #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
 #define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933..546b0a8 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@
 
 void GroupIterable::IteratorStep::UpdateEndOfGroup() {
   ++next_loc_;
-  int64 N = iter_->ix_.dim_size(0);
-  auto ix_t = iter_->ix_.template matrix<int64>();
+  const auto& ix_t = iter_->ix_matrix_;
+  const int64 N = ix_t.dimension(0);
   while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
     ++next_loc_;
   }
@@ -54,7 +54,7 @@
 
 std::vector<int64> Group::group() const {
   std::vector<int64> g;
-  auto ix_t = iter_->ix_.template matrix<int64>();
+  const auto& ix_t = iter_->ix_matrix_;
   for (const int d : iter_->group_dims_) {
     g.push_back(ix_t(loc_, d));
   }
@@ -62,8 +62,8 @@
 }
 
 TTypes<int64>::UnalignedConstMatrix Group::indices() const {
-  return TTypes<int64>::UnalignedConstMatrix(
-      &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+  return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+                                             next_loc_ - loc_, iter_->dims_);
 }
 
 }  // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6..14610c6 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@
 
   GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
       : ix_(ix),
+        ix_matrix_(ix_.matrix<int64>()),
         vals_(vals),
         dims_(dims),
         group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@
 
  private:
   friend class Group;
-  Tensor ix_;
+  const Tensor ix_;
+  const TTypes<int64>::ConstMatrix ix_matrix_;
   Tensor vals_;
   const int dims_;
   const gtl::InlinedVector<int64, 8> group_dims_;
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index dac9b7a..82bc3ff 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -121,10 +121,6 @@
 2.  The Android NDK is required to build the native (C/C++) TensorFlow code. The
     current recommended version is 14b, which may be found
     [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
-
-      * NDK 16, the revision released in November 2017, is **incompatible** with
-        Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918).
-
 3.  The Android SDK and build tools may be obtained
     [here](https://developer.android.com/tools/revisions/build-tools.html), or
     alternatively as part of [Android
@@ -132,10 +128,6 @@
     23 is required to build the TF Android demo (though it will run on API >= 21
     devices).
 
-      - The Android Studio SDK Manager's NDK installer will install the latest
-        revision of the NDK, which is **incompatible** with Bazel. You'll need
-        to download an older version manually, as (2) suggests.
-
 ##### Edit WORKSPACE
 
 NOTE: As long as you have the SDK and NDK installed, the `./configure` script
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD
similarity index 100%
rename from tensorflow/contrib/autograph/examples/integration_tests/BUILD
rename to tensorflow/examples/autograph/integration_tests/BUILD
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
similarity index 81%
rename from tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
rename to tensorflow/examples/autograph/integration_tests/errors_test.py
index 04a968b..69e5936 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -20,21 +20,18 @@
 
 import tensorflow as tf
 
-from tensorflow.contrib import autograph as ag
-from tensorflow.python.util import tf_inspect
+from tensorflow.python import autograph as ag
 
 
 class ErrorsTest(tf.test.TestCase):
 
   def test_graph_construction_error_rewriting_call_tree(self):
 
-    def innermost(x):
-      if x > 0:
-        return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
-      return tf.zeros((2, 3))
+    def test_fn():
+      return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
 
     def inner_caller():
-      return innermost(1.0)
+      return test_fn()
 
     def caller():
       return inner_caller()
@@ -45,23 +42,21 @@
     expected = error.exception
     custom_traceback = expected.custom_traceback
     found_correct_filename = False
-    num_innermost_names = 0
+    num_test_fn_names = 0
     num_inner_caller_names = 0
     num_caller_names = 0
-    ag_output_filename = tf_inspect.getsourcefile(graph)
     for frame in custom_traceback:
       filename, _, fn_name, _ = frame
-      self.assertFalse('control_flow_ops.py' in filename)
-      self.assertFalse(ag_output_filename in filename)
+      self.assertFalse('/tmp/' in filename)
       found_correct_filename |= __file__ in filename
       self.assertNotEqual('tf__test_fn', fn_name)
-      num_innermost_names += int('innermost' == fn_name)
+      num_test_fn_names += int('test_fn' == fn_name)
       self.assertNotEqual('tf__inner_caller', fn_name)
       num_inner_caller_names += int('inner_caller' == fn_name)
       self.assertNotEqual('tf__caller', fn_name)
       num_caller_names += int('caller' == fn_name)
     self.assertTrue(found_correct_filename)
-    self.assertEqual(num_innermost_names, 1)
+    self.assertEqual(num_test_fn_names, 1)
     self.assertEqual(num_inner_caller_names, 1)
     self.assertEqual(num_caller_names, 1)
 
@@ -97,7 +92,7 @@
     compiled_fn = ag.to_graph(test_fn)
 
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.cached_session() as sess:
+      with self.test_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)
@@ -106,19 +101,14 @@
     found_correct_filename = False
     num_test_fn_frames = 0
     num_g_frames = 0
-    ag_output_filename = tf_inspect.getsourcefile(compiled_fn)
     for frame in custom_traceback:
       filename, _, fn_name, source_code = frame
-      self.assertFalse(ag_output_filename in filename)
-      self.assertFalse('control_flow_ops.py' in filename)
+      self.assertFalse('/tmp/' in filename)
+      self.assertFalse('control_flow.py' in filename)
       self.assertFalse('ag__.' in fn_name)
-      self.assertFalse('tf__g' in fn_name)
-      self.assertFalse('tf__test_fn' in fn_name)
       found_correct_filename |= __file__ in filename
       num_test_fn_frames += int('test_fn' == fn_name and
                                 'return g(x, 10)' in source_code)
-      # This makes sure that the code is correctly rewritten from "x_1 //= 0" to
-      # "x //= 0".
       num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
     self.assertTrue(found_correct_filename)
     self.assertEqual(num_test_fn_frames, 1)
@@ -144,7 +134,7 @@
     # frame with "g" as the function name but because we don't yet add
     # try/except blocks to inner functions the name is "tf__g".
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.cached_session() as sess:
+      with self.test_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
similarity index 98%
rename from tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
rename to tensorflow/examples/autograph/integration_tests/keras_test.py
index 7e7ef5a..dca7c07 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -20,7 +20,7 @@
 
 import tensorflow as tf
 
-from tensorflow.contrib import autograph
+from tensorflow.python import autograph
 
 
 class MinimalKeras(tf.keras.Model):
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
rename to tensorflow/examples/autograph/integration_tests/list_literals_test.py
index 904246a..917f5ff 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
+++ b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
@@ -20,7 +20,7 @@
 
 import tensorflow as tf
 
-from tensorflow.contrib import autograph as ag
+from tensorflow.python import autograph as ag
 
 
 def list_used_as_tuple():
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index c8de6c2..0c7ca9b 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -25,7 +25,7 @@
 class FreezeTest(test.TestCase):
 
   def testCreateInferenceGraphWithMfcc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       freeze.create_inference_graph(
           wanted_words='a,b,c,d',
           sample_rate=16000,
@@ -44,7 +44,7 @@
       self.assertEqual(1, ops.count('Mfcc'))
 
   def testCreateInferenceGraphWithoutMfcc(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       freeze.create_inference_graph(
           wanted_words='a,b,c,d',
           sample_rate=16000,
@@ -63,7 +63,7 @@
       self.assertEqual(0, ops.count('Mfcc'))
 
   def testFeatureBinCount(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       freeze.create_inference_graph(
           wanted_words='a,b,c,d',
           sample_rate=16000,
diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py
index 2e551be..aa4e807 100644
--- a/tensorflow/examples/speech_commands/input_data_test.py
+++ b/tensorflow/examples/speech_commands/input_data_test.py
@@ -32,7 +32,7 @@
 class InputDataTest(test.TestCase):
 
   def _getWavData(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sample_data = tf.zeros([32000, 2])
       wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
       wav_data = sess.run(wav_encoder)
@@ -75,7 +75,7 @@
       self._saveTestWavFile(file_path, wav_data)
     model_settings = models.prepare_model_settings(
         4, 16000, 1000, window_length_ms, 20, 40, preprocess)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       audio_processor = input_data.AudioProcessor(
           "", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
       result_data, result_labels = audio_processor.get_data(
diff --git a/tensorflow/examples/speech_commands/label_wav_test.py b/tensorflow/examples/speech_commands/label_wav_test.py
index 80ca774..f0af2a4 100644
--- a/tensorflow/examples/speech_commands/label_wav_test.py
+++ b/tensorflow/examples/speech_commands/label_wav_test.py
@@ -30,7 +30,7 @@
 class LabelWavTest(test.TestCase):
 
   def _getWavData(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sample_data = tf.zeros([1000, 2])
       wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
       wav_data = sess.run(wav_encoder)
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 0c37396..04478c0 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -49,7 +49,7 @@
 
   def testCreateModelConvTraining(self):
     model_settings = self._modelSettings()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       logits, dropout_prob = models.create_model(fingerprint_input,
                                                  model_settings, "conv", True)
@@ -60,7 +60,7 @@
 
   def testCreateModelConvInference(self):
     model_settings = self._modelSettings()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       logits = models.create_model(fingerprint_input, model_settings, "conv",
                                    False)
@@ -69,7 +69,7 @@
 
   def testCreateModelLowLatencyConvTraining(self):
     model_settings = self._modelSettings()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       logits, dropout_prob = models.create_model(
           fingerprint_input, model_settings, "low_latency_conv", True)
@@ -80,7 +80,7 @@
 
   def testCreateModelFullyConnectedTraining(self):
     model_settings = self._modelSettings()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       logits, dropout_prob = models.create_model(
           fingerprint_input, model_settings, "single_fc", True)
@@ -91,7 +91,7 @@
 
   def testCreateModelBadArchitecture(self):
     model_settings = self._modelSettings()
-    with self.test_session():
+    with self.cached_session():
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       with self.assertRaises(Exception) as e:
         models.create_model(fingerprint_input, model_settings,
@@ -100,7 +100,7 @@
 
   def testCreateModelTinyConvTraining(self):
     model_settings = self._modelSettings()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
       logits, dropout_prob = models.create_model(
           fingerprint_input, model_settings, "tiny_conv", True)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 5ebd409..322b35d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3401,51 +3401,84 @@
 	return op.Output(0)
 }
 
-// Computes the mean along sparse segments of a tensor.
+// Runs multiple additive regression ensemble predictors on input instances and
 //
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
 //
 // Arguments:
 //
-//	indices: A 1-D tensor. Has same rank as `segment_ids`.
-//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//	cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+//	cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+//	bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+//	logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
 //
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+	opspec := tf.OpSpec{
+		Type: "BoostedTreesTrainingPredict",
+		Input: []tf.Input{
+			tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Serializes the tree ensemble to a proto.
+//
+// Arguments:
+//	tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble.
+func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
 	opspec := tf.OpSpec{
-		Type: "SparseSegmentMean",
+		Type: "BoostedTreesSerializeEnsemble",
 		Input: []tf.Input{
-			data, indices, segment_ids,
+			tree_ensemble_handle,
 		},
 	}
 	op := scope.AddOperation(opspec)
-	return op.Output(0)
+	return op.Output(0), op.Output(1)
 }
 
-// Pop the element at the top of the stack.
+// Debugging/model interpretability outputs for each example.
+//
+// It traverses all the trees and computes debug metrics for individual examples,
+// such as getting split feature ids and logits after each split along the decision
+// path used to compute directional feature contributions.
 //
 // Arguments:
-//	handle: The handle to a stack.
-//	elem_type: The type of the elem that is popped.
 //
-// Returns The tensor that is popped from the top of the stack.
-func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+//	bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+//	logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
+// examples_debug_outputs_serialized.
+//
+// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
+func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
 	if scope.Err() != nil {
 		return
 	}
-	attrs := map[string]interface{}{"elem_type": elem_type}
+	attrs := map[string]interface{}{"logits_dimension": logits_dimension}
 	opspec := tf.OpSpec{
-		Type: "StackPopV2",
+		Type: "BoostedTreesExampleDebugOutputs",
 		Input: []tf.Input{
-			handle,
+			tree_ensemble_handle, tf.OutputList(bucketized_features),
 		},
 		Attrs: attrs,
 	}
@@ -8159,47 +8192,6 @@
 	return op.Output(0)
 }
 
-// RandomPoissonAttr is an optional argument to RandomPoisson.
-type RandomPoissonAttr func(optionalAttr)
-
-// RandomPoissonSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed(value int64) RandomPoissonAttr {
-	return func(m optionalAttr) {
-		m["seed"] = value
-	}
-}
-
-// RandomPoissonSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed2(value int64) RandomPoissonAttr {
-	return func(m optionalAttr) {
-		m["seed2"] = value
-	}
-}
-
-// Use RandomPoissonV2 instead.
-//
-// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
-func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "RandomPoisson",
-		Input: []tf.Input{
-			shape, rate,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Returns the element-wise sum of a list of tensors.
 //
 // `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
@@ -8348,6 +8340,377 @@
 	return op.Output(0)
 }
 
+// Returns the truth value of (x > y) element-wise.
+//
+// *NOTE*: `Greater` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Greater",
+		Input: []tf.Input{
+			x, y,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
+type ResourceSparseApplyRMSPropAttr func(optionalAttr)
+
+// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
+	return func(m optionalAttr) {
+		m["use_locking"] = value
+	}
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+//	var_: Should be from a Variable().
+//	ms: Should be from a Variable().
+//	mom: Should be from a Variable().
+//	lr: Scaling factor. Must be a scalar.
+//	rho: Decay rate. Must be a scalar.
+//
+//	epsilon: Ridge term. Must be a scalar.
+//	grad: The gradient.
+//	indices: A vector of indices into the first dimension of var, ms and mom.
+//
+// Returns the created operation.
+func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "ResourceSparseApplyRMSProp",
+		Input: []tf.Input{
+			var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
+		},
+		Attrs: attrs,
+	}
+	return scope.AddOperation(opspec)
+}
+
+// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
+type SampleDistortedBoundingBoxAttr func(optionalAttr)
+
+// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to non-zero, the random number
+// generator is seeded by the given `seed`.  Otherwise, it is seeded by a random
+// seed.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["seed"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["seed2"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
+//
+// value: The cropped area of the image must contain at least this
+// fraction of any bounding box supplied. The value of this parameter should be
+// non-negative. In the case of 0, the cropped area does not need to overlap
+// any of the bounding boxes supplied.
+// If not specified, defaults to 0.1
+func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["min_object_covered"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
+//
+// value: The cropped area of the image must have an aspect ratio =
+// width / height within this range.
+// If not specified, defaults to <f:0.75 f:1.33 >
+func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["aspect_ratio_range"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
+//
+// value: The cropped area of the image must contain a fraction of the
+// supplied image within this range.
+// If not specified, defaults to <f:0.05 f:1 >
+func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["area_range"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
+//
+// value: Number of attempts at generating a cropped region of the image
+// of the specified constraints. After `max_attempts` failures, return the entire
+// image.
+// If not specified, defaults to 100
+func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["max_attempts"] = value
+	}
+}
+
+// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
+//
+// value: Controls behavior if no bounding boxes supplied.
+// If true, assume an implicit bounding box covering the whole input. If false,
+// raise an error.
+// If not specified, defaults to false
+func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
+	return func(m optionalAttr) {
+		m["use_image_if_no_bounding_boxes"] = value
+	}
+}
+
+// Generate a single randomly distorted bounding box for an image.
+//
+// Bounding box annotations are often supplied in addition to ground-truth labels
+// in image recognition or object localization tasks. A common technique for
+// training such a system is to randomly distort an image while preserving
+// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
+// localization of an object, i.e. bounding box, given an `image_size`,
+// `bounding_boxes` and a series of constraints.
+//
+// The output of this Op is a single bounding box that may be used to crop the
+// original image. The output is returned as 3 tensors: `begin`, `size` and
+// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
+// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+// what the bounding box looks like.
+//
+// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example,
+//
+// ```python
+//     # Generate a single distorted bounding box.
+//     begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+//         tf.shape(image),
+//         bounding_boxes=bounding_boxes)
+//
+//     # Draw the bounding box in an image summary.
+//     image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+//                                                   bbox_for_draw)
+//     tf.summary.image('images_with_box', image_with_box)
+//
+//     # Employ the bounding box to distort the image.
+//     distorted_image = tf.slice(image, begin, size)
+// ```
+//
+// Note that if no bounding box information is available, setting
+// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
+// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
+// false and no bounding boxes are supplied, an error is raised.
+//
+// Arguments:
+//	image_size: 1-D, containing `[height, width, channels]`.
+//	bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
+// associated with the image.
+//
+// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to
+// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+// Provide as input to `tf.image.draw_bounding_boxes`.
+func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "SampleDistortedBoundingBox",
+		Input: []tf.Input{
+			image_size, bounding_boxes,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Computes sigmoid of `x` element-wise.
+//
+// Specifically, `y = 1 / (1 + exp(-x))`.
+func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Sigmoid",
+		Input: []tf.Input{
+			x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
+type FusedBatchNormAttr func(optionalAttr)
+
+// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
+//
+// value: A small float number added to the variance of x.
+// If not specified, defaults to 0.0001
+func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
+	return func(m optionalAttr) {
+		m["epsilon"] = value
+	}
+}
+
+// FusedBatchNormDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
+// If not specified, defaults to "NHWC"
+func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
+	return func(m optionalAttr) {
+		m["data_format"] = value
+	}
+}
+
+// FusedBatchNormIsTraining sets the optional is_training attribute to value.
+//
+// value: A bool value to indicate the operation is for training (default)
+// or inference.
+// If not specified, defaults to true
+func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
+	return func(m optionalAttr) {
+		m["is_training"] = value
+	}
+}
+
+// Batch normalization.
+//
+// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
+// The size of 1D Tensors matches the dimension C of the 4D Tensors.
+//
+// Arguments:
+//	x: A 4D Tensor for input data.
+//	scale: A 1D Tensor for scaling factor, to scale the normalized x.
+//	offset: A 1D Tensor for offset, to shift to the normalized x.
+//	mean: A 1D Tensor for population mean. Used for inference only;
+// must be empty for training.
+//	variance: A 1D Tensor for population variance. Used for inference only;
+// must be empty for training.
+//
+// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
+// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
+// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
+// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
+// in the cuDNN case), to be reused in the gradient computation.
+func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "FusedBatchNorm",
+		Input: []tf.Input{
+			x, scale, offset, mean, variance,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
+// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
+type RandomStandardNormalAttr func(optionalAttr)
+
+// RandomStandardNormalSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed.  Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
+	return func(m optionalAttr) {
+		m["seed"] = value
+	}
+}
+
+// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
+	return func(m optionalAttr) {
+		m["seed2"] = value
+	}
+}
+
+// Outputs random values from a normal distribution.
+//
+// The generated values will have mean 0 and standard deviation 1.
+//
+// Arguments:
+//	shape: The shape of the output tensor.
+//	dtype: The type of the output.
+//
+// Returns A tensor of the specified shape filled with random normal values.
+func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"dtype": dtype}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "RandomStandardNormal",
+		Input: []tf.Input{
+			shape,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
 type ResourceApplyFtrlAttr func(optionalAttr)
 
@@ -12357,235 +12720,6 @@
 	return values
 }
 
-// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
-type ResourceSparseApplyRMSPropAttr func(optionalAttr)
-
-// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
-	return func(m optionalAttr) {
-		m["use_locking"] = value
-	}
-}
-
-// Update '*var' according to the RMSProp algorithm.
-//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
-//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
-//
-// Arguments:
-//	var_: Should be from a Variable().
-//	ms: Should be from a Variable().
-//	mom: Should be from a Variable().
-//	lr: Scaling factor. Must be a scalar.
-//	rho: Decay rate. Must be a scalar.
-//
-//	epsilon: Ridge term. Must be a scalar.
-//	grad: The gradient.
-//	indices: A vector of indices into the first dimension of var, ms and mom.
-//
-// Returns the created operation.
-func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "ResourceSparseApplyRMSProp",
-		Input: []tf.Input{
-			var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
-		},
-		Attrs: attrs,
-	}
-	return scope.AddOperation(opspec)
-}
-
-// Returns the truth value of (x > y) element-wise.
-//
-// *NOTE*: `Greater` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Greater",
-		Input: []tf.Input{
-			x, y,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
-type SampleDistortedBoundingBoxAttr func(optionalAttr)
-
-// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to non-zero, the random number
-// generator is seeded by the given `seed`.  Otherwise, it is seeded by a random
-// seed.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["seed"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["seed2"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
-//
-// value: The cropped area of the image must contain at least this
-// fraction of any bounding box supplied. The value of this parameter should be
-// non-negative. In the case of 0, the cropped area does not need to overlap
-// any of the bounding boxes supplied.
-// If not specified, defaults to 0.1
-func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["min_object_covered"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
-//
-// value: The cropped area of the image must have an aspect ratio =
-// width / height within this range.
-// If not specified, defaults to <f:0.75 f:1.33 >
-func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["aspect_ratio_range"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
-//
-// value: The cropped area of the image must contain a fraction of the
-// supplied image within this range.
-// If not specified, defaults to <f:0.05 f:1 >
-func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["area_range"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
-//
-// value: Number of attempts at generating a cropped region of the image
-// of the specified constraints. After `max_attempts` failures, return the entire
-// image.
-// If not specified, defaults to 100
-func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["max_attempts"] = value
-	}
-}
-
-// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
-//
-// value: Controls behavior if no bounding boxes supplied.
-// If true, assume an implicit bounding box covering the whole input. If false,
-// raise an error.
-// If not specified, defaults to false
-func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
-	return func(m optionalAttr) {
-		m["use_image_if_no_bounding_boxes"] = value
-	}
-}
-
-// Generate a single randomly distorted bounding box for an image.
-//
-// Bounding box annotations are often supplied in addition to ground-truth labels
-// in image recognition or object localization tasks. A common technique for
-// training such a system is to randomly distort an image while preserving
-// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
-// localization of an object, i.e. bounding box, given an `image_size`,
-// `bounding_boxes` and a series of constraints.
-//
-// The output of this Op is a single bounding box that may be used to crop the
-// original image. The output is returned as 3 tensors: `begin`, `size` and
-// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
-// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
-// what the bounding box looks like.
-//
-// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example,
-//
-// ```python
-//     # Generate a single distorted bounding box.
-//     begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
-//         tf.shape(image),
-//         bounding_boxes=bounding_boxes)
-//
-//     # Draw the bounding box in an image summary.
-//     image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
-//                                                   bbox_for_draw)
-//     tf.summary.image('images_with_box', image_with_box)
-//
-//     # Employ the bounding box to distort the image.
-//     distorted_image = tf.slice(image, begin, size)
-// ```
-//
-// Note that if no bounding box information is available, setting
-// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
-// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
-// false and no bounding boxes are supplied, an error is raised.
-//
-// Arguments:
-//	image_size: 1-D, containing `[height, width, channels]`.
-//	bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
-// associated with the image.
-//
-// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
-// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to
-// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box.
-// Provide as input to `tf.image.draw_bounding_boxes`.
-func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "SampleDistortedBoundingBox",
-		Input: []tf.Input{
-			image_size, bounding_boxes,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2)
-}
-
 // LRNAttr is an optional argument to LRN.
 type LRNAttr func(optionalAttr)
 
@@ -13788,34 +13922,6 @@
 	return op.Output(0), op.Output(1)
 }
 
-// Fast Fourier transform.
-//
-// Computes the 1-dimensional discrete Fourier transform over the inner-most
-// dimension of `input`.
-//
-// Arguments:
-//	input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most
-//   dimension of `input` is replaced with its 1D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.fft
-// @end_compatibility
-func FFT(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "FFT",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Transforms a serialized tensorflow.TensorProto proto into a Tensor.
 //
 // Arguments:
@@ -14396,6 +14502,47 @@
 	return op.Output(0)
 }
 
+// RandomPoissonAttr is an optional argument to RandomPoisson.
+type RandomPoissonAttr func(optionalAttr)
+
+// RandomPoissonSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed(value int64) RandomPoissonAttr {
+	return func(m optionalAttr) {
+		m["seed"] = value
+	}
+}
+
+// RandomPoissonSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed2(value int64) RandomPoissonAttr {
+	return func(m optionalAttr) {
+		m["seed2"] = value
+	}
+}
+
+// Use RandomPoissonV2 instead.
+//
+// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
+func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "RandomPoisson",
+		Input: []tf.Input{
+			shape, rate,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
 type LogUniformCandidateSamplerAttr func(optionalAttr)
 
@@ -16136,148 +16283,6 @@
 	return scope.AddOperation(opspec)
 }
 
-// Computes sigmoid of `x` element-wise.
-//
-// Specifically, `y = 1 / (1 + exp(-x))`.
-func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "Sigmoid",
-		Input: []tf.Input{
-			x,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
-type FusedBatchNormAttr func(optionalAttr)
-
-// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
-//
-// value: A small float number added to the variance of x.
-// If not specified, defaults to 0.0001
-func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
-	return func(m optionalAttr) {
-		m["epsilon"] = value
-	}
-}
-
-// FusedBatchNormDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
-// If not specified, defaults to "NHWC"
-func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
-	return func(m optionalAttr) {
-		m["data_format"] = value
-	}
-}
-
-// FusedBatchNormIsTraining sets the optional is_training attribute to value.
-//
-// value: A bool value to indicate the operation is for training (default)
-// or inference.
-// If not specified, defaults to true
-func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
-	return func(m optionalAttr) {
-		m["is_training"] = value
-	}
-}
-
-// Batch normalization.
-//
-// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-// The size of 1D Tensors matches the dimension C of the 4D Tensors.
-//
-// Arguments:
-//	x: A 4D Tensor for input data.
-//	scale: A 1D Tensor for scaling factor, to scale the normalized x.
-//	offset: A 1D Tensor for offset, to shift to the normalized x.
-//	mean: A 1D Tensor for population mean. Used for inference only;
-// must be empty for training.
-//	variance: A 1D Tensor for population variance. Used for inference only;
-// must be empty for training.
-//
-// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
-// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
-// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
-// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
-// in the cuDNN case), to be reused in the gradient computation.
-func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "FusedBatchNorm",
-		Input: []tf.Input{
-			x, scale, offset, mean, variance,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
-// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
-type RandomStandardNormalAttr func(optionalAttr)
-
-// RandomStandardNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed.  Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
-	return func(m optionalAttr) {
-		m["seed"] = value
-	}
-}
-
-// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
-	return func(m optionalAttr) {
-		m["seed2"] = value
-	}
-}
-
-// Outputs random values from a normal distribution.
-//
-// The generated values will have mean 0 and standard deviation 1.
-//
-// Arguments:
-//	shape: The shape of the output tensor.
-//	dtype: The type of the output.
-//
-// Returns A tensor of the specified shape filled with random normal values.
-func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"dtype": dtype}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "RandomStandardNormal",
-		Input: []tf.Input{
-			shape,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Component-wise divides a SparseTensor by a dense Tensor.
 //
 // *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
@@ -17427,26 +17432,6 @@
 	return op.Output(0)
 }
 
-// Serializes the tree ensemble to a proto.
-//
-// Arguments:
-//	tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble.
-func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "BoostedTreesSerializeEnsemble",
-		Input: []tf.Input{
-			tree_ensemble_handle,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1)
-}
-
 // StageSizeAttr is an optional argument to StageSize.
 type StageSizeAttr func(optionalAttr)
 
@@ -20376,6 +20361,58 @@
 	return op.Output(0)
 }
 
+// Computes the mean along sparse segments of a tensor.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// Arguments:
+//
+//	indices: A 1-D tensor. Has same rank as `segment_ids`.
+//	segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SparseSegmentMean",
+		Input: []tf.Input{
+			data, indices, segment_ids,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Pop the element at the top of the stack.
+//
+// Arguments:
+//	handle: The handle to a stack.
+//	elem_type: The type of the elem that is popped.
+//
+// Returns The tensor that is popped from the top of the stack.
+func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"elem_type": elem_type}
+	opspec := tf.OpSpec{
+		Type: "StackPopV2",
+		Input: []tf.Input{
+			handle,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Computes hyperbolic cosine of x element-wise.
 func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
 	if scope.Err() != nil {
@@ -26601,36 +26638,6 @@
 	return op.Output(0)
 }
 
-// Debugging/model interpretability outputs for each example.
-//
-// It traverses all the trees and computes debug metrics for individual examples,
-// such as getting split feature ids and logits after each split along the decision
-// path used to compute directional feature contributions.
-//
-// Arguments:
-//
-//	bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-//	logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
-// examples_debug_outputs_serialized.
-//
-// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
-func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"logits_dimension": logits_dimension}
-	opspec := tf.OpSpec{
-		Type: "BoostedTreesExampleDebugOutputs",
-		Input: []tf.Input{
-			tree_ensemble_handle, tf.OutputList(bucketized_features),
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Adds a value to the current value of a variable.
 //
 // Any ReadVariableOp with a control dependency on this op is guaranteed to
@@ -28118,6 +28125,34 @@
 	return op.Output(0)
 }
 
+// Fast Fourier transform.
+//
+// Computes the 1-dimensional discrete Fourier transform over the inner-most
+// dimension of `input`.
+//
+// Arguments:
+//	input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most
+//   dimension of `input` is replaced with its 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.fft
+// @end_compatibility
+func FFT(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "FFT",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Performs a padding as a preprocess during a convolution.
 //
 // Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -31743,54 +31778,6 @@
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
-	return func(m optionalAttr) {
-		m["container"] = value
-	}
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
-	return func(m optionalAttr) {
-		m["shared_name"] = value
-	}
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue.  The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "WholeFileReaderV2",
-
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Transforms a tf.Example proto (as a string) into typed tensors.
 //
 // Arguments:
@@ -31861,6 +31848,54 @@
 	return sparse_indices, sparse_values, sparse_shapes, dense_values
 }
 
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+	return func(m optionalAttr) {
+		m["container"] = value
+	}
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+	return func(m optionalAttr) {
+		m["shared_name"] = value
+	}
+}
+
+// A Reader that outputs the entire contents of a file as a value.
+//
+// To use, enqueue filenames in a Queue.  The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
+//
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "WholeFileReaderV2",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Deserializes a serialized tree ensemble config and replaces current tree
 //
 // ensemble.
@@ -31883,38 +31918,3 @@
 	}
 	return scope.AddOperation(opspec)
 }
-
-// Runs multiple additive regression ensemble predictors on input instances and
-//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
-//
-// Arguments:
-//
-//	cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-//	cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-//	bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-//	logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"logits_dimension": logits_dimension}
-	opspec := tf.OpSpec{
-		Type: "BoostedTreesTrainingPredict",
-		Input: []tf.Input{
-			tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0), op.Output(1), op.Output(2)
-}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5af6437..2dc2808 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -78,6 +78,7 @@
         "//tensorflow:__pkg__",
         "//tensorflow/python/tools:__pkg__",
         "//tensorflow/python/tools/api/generator:__pkg__",
+        "//tensorflow/tools/api/tests:__pkg__",
     ],
     deps = [
         ":array_ops",
@@ -2090,6 +2091,18 @@
     srcs = [
         "ops/custom_gradient.py",
         "ops/gradients.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gradients_impl",
+        "//tensorflow/python/eager:function",
+        "//tensorflow/python/eager:tape",
+    ],
+)
+
+py_library(
+    name = "gradients_impl",
+    srcs = [
         "ops/gradients_impl.py",
     ],
     srcs_version = "PY2AND3",
@@ -3045,6 +3058,7 @@
         ":functional_ops",
         ":gradients",
         ":layers",
+        ":list_ops",
         ":math_grad",
         ":math_ops",
         ":nn_grad",
@@ -4381,6 +4395,7 @@
         "training/ftrl_test.py",
         "training/gradient_descent_test.py",
         "training/learning_rate_decay_test.py",
+        "training/learning_rate_decay_v2_test.py",
         "training/momentum_test.py",
         "training/optimizer_test.py",
         "training/proximal_adagrad_test.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index a2ab63b..4921ecc 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -48,6 +48,13 @@
 
 from tensorflow.python import pywrap_tensorflow
 
+from tensorflow.python.tools import component_api_helper
+component_api_helper.package_hook(
+    parent_package_str='tensorflow.python',
+    child_package_str=(
+        'tensorflow_estimator.python.estimator'))
+del component_api_helper
+
 # Protocol buffers
 from tensorflow.core.framework.graph_pb2 import *
 from tensorflow.core.framework.node_def_pb2 import *
diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD
new file mode 100644
index 0000000..3289b44
--- /dev/null
+++ b/tensorflow/python/autograph/BUILD
@@ -0,0 +1,31 @@
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+    name = "autograph",
+    srcs = [
+        "__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python/autograph/impl",
+        "//tensorflow/python/autograph/lang",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
+    ],
+)
diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
similarity index 91%
rename from tensorflow/contrib/autograph/CONTRIBUTING.md
rename to tensorflow/python/autograph/CONTRIBUTING.md
index 06fb7b0..1ded5ba 100644
--- a/tensorflow/contrib/autograph/CONTRIBUTING.md
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -2,6 +2,15 @@
 
 We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
 
+### Note to active contributors
+
+In preparation for TF 2.0, we moved the code base of AutoGraph from
+`tensorflow/contrib/autograph` to `tensorflow/python/autograph`. The move
+does not impact functionality, and AutoGraph will remain accessible under
+`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
+
+When contributing new changes, please send your patches against the new location, `tensorflow/python/autograph`.
+
 ## TensorFlow Code of Conduct
 Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
 
diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/python/autograph/LIMITATIONS.md
similarity index 100%
rename from tensorflow/contrib/autograph/LIMITATIONS.md
rename to tensorflow/python/autograph/LIMITATIONS.md
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
new file mode 100644
index 0000000..cc54da4
--- /dev/null
+++ b/tensorflow/python/autograph/README.md
@@ -0,0 +1,143 @@
+# AutoGraph
+
+IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+
+AutoGraph is a Python to TensorFlow compiler.
+
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops.  [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
+
+For example, this Python function:
+
+```
+def f(x):
+  if x < 0:
+    x = -x
+  return x
+```
+
+would be converted to this:
+
+```
+def graph_mode_f(x):
+  with tf.name_scope('f'):
+
+    def if_true():
+      with tf.name_scope('if_true'):
+        x_1, = x,
+        x_1 = tf.negative(x_1)
+        return x_1,
+
+    def if_false():
+      with tf.name_scope('if_false'):
+        x_1, = x,
+        return x_1,
+    x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
+    return x
+```
+
+so you can use it like an op:
+
+```
+with tf.Graph().as_default():
+  x = tf.constant(-1.0)
+
+  converted_f = autograph.to_graph(f)
+  y = converted_f(x)
+
+  with tf.Session() as sess:
+    print(sess.run(y))
+    # Output: 1
+```
+
+# Getting started
+
+Use AutoGraph in one of the following ways, described below:
+
+ 1. Annotations (simpler)
+ 2. Functional API (more flexible)
+
+To get started, install the latest nightly TensorFlow build:
+
+```shell
+pip install -U tf-nightly
+```
+
+Then import the `autograph` module from `tf.contrib`:
+
+```
+from tensorflow.contrib import autograph as ag
+```
+
+### Related links
+
+Articles:
+
+ * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
+
+Interactive notebooks:
+
+ * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
+ * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
+ * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
+ * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
+ * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
+ * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
+ * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
+
+## Using with annotations
+
+Annotating a function or class with `@convert` converts it in place:
+
+```
+@ag.convert()
+def f(x):
+  if x < 0:
+    x = -x
+  return x
+```
+
+... so that it always outputs TensorFlow code:
+
+```
+with tf.Graph().as_default():
+  x = tf.constant(-1)
+
+  y = f(x)
+
+  with tf.Session() as sess:
+    print(sess.run(y))
+    # Output: 1
+```
+
+## Using the functional API
+
+The functional API allows you to convert an existing function, class or object after it was defined:
+
+```
+converted_f = ag.to_graph(f)
+
+print(converted_f(tf.constant(-1)))
+# Output: Tensor
+
+print(f(-1))
+# Output: 1
+```
+
+You can use the functional API to inspect the generated code as well:
+
+```
+print(ag.to_code(f))
+# Output: <Python and TensorFlow code>
+```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/python/autograph/STYLE_GUIDE.md
similarity index 100%
rename from tensorflow/contrib/autograph/STYLE_GUIDE.md
rename to tensorflow/python/autograph/STYLE_GUIDE.md
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
new file mode 100644
index 0000000..c3448e6
--- /dev/null
+++ b/tensorflow/python/autograph/__init__.py
@@ -0,0 +1,68 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Autograph compiles Python code into equivalent TensorFlow code.
+
+Equivalent here means that they have the same effect when executed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Bring only the relevant symbols to the top level.
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core.errors import GraphConstructionError
+from tensorflow.python.autograph.core.errors import TfRuntimeError
+from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import RunMode
+from tensorflow.python.autograph.impl.api import convert
+from tensorflow.python.autograph.impl.api import converted_call
+from tensorflow.python.autograph.impl.api import do_not_convert
+from tensorflow.python.autograph.impl.api import to_code
+from tensorflow.python.autograph.impl.api import to_graph
+from tensorflow.python.autograph.lang.directives import set_element_type
+from tensorflow.python.autograph.lang.directives import set_loop_options
+from tensorflow.python.autograph.lang.special_functions import stack
+from tensorflow.python.autograph.lang.special_functions import tensor_list
+from tensorflow.python.autograph.pyct.transformer import AutographParseError
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    # Main API
+    'RunMode',
+    'convert',
+    'converted_call',
+    'do_not_convert',
+    'to_code',
+    'to_graph',
+    # Overloaded operators
+    'operators',
+    # Errors
+    'improved_errors',
+    'GraphConstructionError',
+    'TfRuntimeError',
+    # Python language "extensions"
+    'set_element_type',
+    'set_loop_options',
+    'stack',
+    'tensor_list',
+    # Exceptions
+    'AutographParseError',
+    # Utilities: to be removed
+    'utils',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
similarity index 76%
rename from tensorflow/contrib/autograph/converters/BUILD
rename to tensorflow/python/autograph/converters/BUILD
index 2d2ab70..7b029de 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -38,11 +38,11 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/core",
-        "//tensorflow/contrib/autograph/lang",
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/pyct/static_analysis",
         "//tensorflow/python:util",
+        "//tensorflow/python/autograph/core",
+        "//tensorflow/python/autograph/lang",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/pyct/static_analysis",
         "@gast_archive//:gast",
     ],
 )
@@ -54,8 +54,8 @@
     tags = ["no_windows"],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -65,8 +65,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -77,8 +77,8 @@
     tags = ["no_windows"],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -90,9 +90,9 @@
     tags = ["no_windows"],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/impl",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/impl",
     ],
 )
 
@@ -102,8 +102,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -113,8 +113,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -124,8 +124,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -139,8 +139,8 @@
     ],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -150,9 +150,9 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/lang",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/lang",
     ],
 )
 
@@ -161,9 +161,9 @@
     srcs = ["name_scopes_test.py"],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -173,8 +173,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -184,8 +184,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -195,8 +195,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -207,8 +207,8 @@
     tags = ["notsan"],
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
     ],
 )
 
@@ -218,9 +218,9 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -230,9 +230,9 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -242,8 +242,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":converters",
-        "//tensorflow/contrib/autograph/core:test_lib",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/core:test_lib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/python/autograph/converters/__init__.py
similarity index 100%
rename from tensorflow/contrib/autograph/converters/__init__.py
rename to tensorflow/python/autograph/converters/__init__.py
diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py
similarity index 93%
rename from tensorflow/contrib/autograph/converters/asserts.py
rename to tensorflow/python/autograph/converters/asserts.py
index af2f20f..56a9753 100644
--- a/tensorflow/contrib/autograph/converters/asserts.py
+++ b/tensorflow/python/autograph/converters/asserts.py
@@ -20,8 +20,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
 
 
 class AssertTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
similarity index 90%
rename from tensorflow/contrib/autograph/converters/asserts_test.py
rename to tensorflow/python/autograph/converters/asserts_test.py
index 38faba4..01282f9 100644
--- a/tensorflow/contrib/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -20,8 +20,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
similarity index 94%
rename from tensorflow/contrib/autograph/converters/break_statements.py
rename to tensorflow/python/autograph/converters/break_statements.py
index 1807796..bd6b0b2 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 class _Break(object):
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/break_statements_test.py
rename to tensorflow/python/autograph/converters/break_statements_test.py
index fcae7d6..39406a9 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.eager import context as tfe_ctx
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
new file mode 100644
index 0000000..b8b268d
--- /dev/null
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -0,0 +1,65 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles builtins and other special functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+
+
+class BuiltinFunctionTransformer(converter.Base):
+  """Handles builtin functions.
+
+  This transformer only covers functions that are translated into a
+  TF equivalent, like `len`.
+  """
+
+  def _convert_builtin(self, f, args, as_expression):
+    template = """
+      ag__.func(args)
+    """
+    if as_expression:
+      return templates.replace_as_expression(
+          template, func=py_builtins.overload_of(f).__name__, args=args)
+    else:
+      return templates.replace(
+          template, func=py_builtins.overload_of(f).__name__, args=args)
+
+  def visit_Call(self, node):
+    node = self.generic_visit(node)
+    if anno.hasanno(node.func, 'live_val'):
+      live_val = anno.getanno(node.func, 'live_val')
+      if live_val in py_builtins.SUPPORTED_BUILTINS:
+        node = self._convert_builtin(live_val, node.args, as_expression=True)
+    return node
+
+  def visit_Print(self, node):
+    node = self.generic_visit(node)
+    args = node.values
+    # Following is the case when calling print(a, b)
+    if len(args) == 1 and isinstance(args[0], gast.Tuple):
+      args = args[0].elts
+    return self._convert_builtin(print, args, as_expression=False)
+
+
+def transform(node, ctx):
+  return BuiltinFunctionTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
similarity index 84%
rename from tensorflow/contrib/autograph/converters/builtin_functions_test.py
rename to tensorflow/python/autograph/converters/builtin_functions_test.py
index d0a0cbb..c87c304 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -20,9 +20,10 @@
 
 import six
 
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
@@ -34,11 +35,11 @@
     def test_fn(a):
       return len(a)
 
-    with self.converted(test_fn, builtin_functions, {'len': len},
-                        array_ops.shape) as result:
+    with self.converted(test_fn, builtin_functions, {'len': len}) as result:
       with self.cached_session() as sess:
-        ops = result.test_fn(constant_op.constant([0, 0, 0]))
-        self.assertEqual(sess.run(ops), 3)
+        p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+        ops = result.test_fn(p)
+        self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
 
   def test_print(self):
 
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
similarity index 97%
rename from tensorflow/contrib/autograph/converters/call_trees.py
rename to tensorflow/python/autograph/converters/call_trees.py
index 2d1bed3..6a606c4 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -26,12 +26,12 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/call_trees_test.py
rename to tensorflow/python/autograph/converters/call_trees_test.py
index ca4d1f2..0e50f42 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -20,8 +20,8 @@
 
 import numpy as np
 
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
similarity index 94%
rename from tensorflow/contrib/autograph/converters/conditional_expressions.py
rename to tensorflow/python/autograph/converters/conditional_expressions.py
index 63f649d..40728f5 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 class _FunctionDefs(object):
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py b/tensorflow/python/autograph/converters/conditional_expressions_test.py
similarity index 91%
rename from tensorflow/contrib/autograph/converters/conditional_expressions_test.py
rename to tensorflow/python/autograph/converters/conditional_expressions_test.py
index 95a3108..dd1f8d4 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/python/autograph/converters/continue_statements.py
similarity index 95%
rename from tensorflow/contrib/autograph/converters/continue_statements.py
rename to tensorflow/python/autograph/converters/continue_statements.py
index 0476e97..584cdc1 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/python/autograph/converters/continue_statements.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 # Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py
similarity index 94%
rename from tensorflow/contrib/autograph/converters/continue_statements_test.py
rename to tensorflow/python/autograph/converters/continue_statements_test.py
index 37c1521..d6aaa50 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/python/autograph/converters/continue_statements_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.eager import context as tfe_ctx
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/control_flow.py
rename to tensorflow/python/autograph/converters/control_flow.py
index 3530fbb..416a60d 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -20,12 +20,12 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis import annos
 
 
 class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/control_flow_test.py
rename to tensorflow/python/autograph/converters/control_flow_test.py
index 1d04ba3..cfa0ea9 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -18,9 +18,9 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/python/autograph/converters/decorators.py
similarity index 97%
rename from tensorflow/contrib/autograph/converters/decorators.py
rename to tensorflow/python/autograph/converters/decorators.py
index 3471bd1..724f0fe 100644
--- a/tensorflow/contrib/autograph/converters/decorators.py
+++ b/tensorflow/python/autograph/converters/decorators.py
@@ -24,8 +24,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/python/autograph/converters/decorators_test.py
similarity index 88%
rename from tensorflow/contrib/autograph/converters/decorators_test.py
rename to tensorflow/python/autograph/converters/decorators_test.py
index 095abc5..fb31c8d 100644
--- a/tensorflow/contrib/autograph/converters/decorators_test.py
+++ b/tensorflow/python/autograph/converters/decorators_test.py
@@ -19,11 +19,13 @@
 from __future__ import print_function
 
 from functools import wraps
+import imp
 
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python import autograph
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.platform import test
 
 
@@ -136,6 +138,12 @@
 
       return inner_fn(a)
 
+    # Work around TensorFlow's symbol suppression mechanism that causes core to
+    # be invisible in the generated code.
+    core_mod = imp.new_module('core')
+    core_mod.converter_testing = converter_testing
+    autograph.core = core_mod
+
     # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
     self.assertEqual(14, test_fn(1))
 
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/directives.py
rename to tensorflow/python/autograph/converters/directives.py
index 77f625b..fc64634 100644
--- a/tensorflow/contrib/autograph/converters/directives.py
+++ b/tensorflow/python/autograph/converters/directives.py
@@ -25,9 +25,9 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.util import tf_inspect
 
 ENCLOSING_LOOP = 'enclosing_loop'
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py
similarity index 88%
rename from tensorflow/contrib/autograph/converters/directives_test.py
rename to tensorflow/python/autograph/converters/directives_test.py
index a2d083b..570fb8e 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/python/autograph/converters/directives_test.py
@@ -18,12 +18,12 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import directives as directives_converter
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core.converter import AgAnno
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import directives as directives_converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core.converter import AgAnno
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/python/autograph/converters/error_handlers.py
similarity index 91%
rename from tensorflow/contrib/autograph/converters/error_handlers.py
rename to tensorflow/python/autograph/converters/error_handlers.py
index 1936821..de46c0c 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers.py
+++ b/tensorflow/python/autograph/converters/error_handlers.py
@@ -22,9 +22,9 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
 
 
 class ErrorRewritingTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/python/autograph/converters/error_handlers_test.py
similarity index 85%
rename from tensorflow/contrib/autograph/converters/error_handlers_test.py
rename to tensorflow/python/autograph/converters/error_handlers_test.py
index 5d61b22..676ff9e 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/python/autograph/converters/error_handlers_test.py
@@ -18,11 +18,11 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py
similarity index 94%
rename from tensorflow/contrib/autograph/converters/list_comprehensions.py
rename to tensorflow/python/autograph/converters/list_comprehensions.py
index ecf4628..5be6cb9 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions.py
@@ -32,8 +32,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
 
 
 # TODO(mdan): This should covert directly to operator calls.
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py
similarity index 92%
rename from tensorflow/contrib/autograph/converters/list_comprehensions_test.py
rename to tensorflow/python/autograph/converters/list_comprehensions_test.py
index 59b5ce9..1e66139 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import list_comprehensions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import list_comprehensions
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/python/autograph/converters/lists.py
similarity index 95%
rename from tensorflow/contrib/autograph/converters/lists.py
rename to tensorflow/python/autograph/converters/lists.py
index a02fc82..8180801 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/python/autograph/converters/lists.py
@@ -32,12 +32,12 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 # Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
similarity index 91%
rename from tensorflow/contrib/autograph/converters/lists_test.py
rename to tensorflow/python/autograph/converters/lists_test.py
index c5e2dcf..f6da845 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -18,12 +18,12 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.lang import special_functions
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.lang import special_functions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
similarity index 86%
rename from tensorflow/contrib/autograph/converters/logical_expressions.py
rename to tensorflow/python/autograph/converters/logical_expressions.py
index 16eb1f0..8c4d53f 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -23,10 +23,10 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
 
 
 # TODO(mdan): Properly extrack boolean ops according to lazy eval rules.
@@ -57,8 +57,6 @@
         gast.NotEq: 'tf.not_equal',
         gast.Or: 'tf.logical_or',
         gast.USub: 'tf.negative',
-        gast.Is: 'autograph_utils.dynamic_is',
-        gast.IsNot: 'autograph_utils.dynamic_is_not'
     }
 
   def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@
         '"a.x or b"; for a workaround, assign the expression to a local '
         'variable and use that instead, for example "tmp = a.x", "tmp or b"')
 
+  def _has_matching_func(self, operator):
+    op_type = type(operator)
+    return op_type in self.op_mapping
+
   def _matching_func(self, operator):
     op_type = type(operator)
-    mapped_op = self.op_mapping.get(op_type)
-    if not mapped_op:
-      raise NotImplementedError('operator %s is not yet supported' % op_type)
-    return mapped_op
+    return self.op_mapping[op_type]
 
   def _as_function(self, func_name, args):
     template = """
@@ -90,6 +89,16 @@
 
   def visit_Compare(self, node):
     node = self.generic_visit(node)
+
+    if not all(self._has_matching_func(op) for op in node.ops):
+      if len(node.ops) == 1:
+        # Basic expressions are safe to leave as they are.
+        return node
+      else:
+        raise NotImplementedError(
+            'compound expression with at least one unsupported '
+            'operator: {}'.format(node.ops))
+
     ops_and_comps = list(zip(node.ops, node.comparators))
     left = node.left
     op_tree = None
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
similarity index 83%
rename from tensorflow/contrib/autograph/converters/logical_expressions_test.py
rename to tensorflow/python/autograph/converters/logical_expressions_test.py
index 8f9eee7..b78b4d3 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
@@ -47,6 +47,13 @@
       with self.cached_session() as sess:
         self.assertTrue(sess.run(result.test_fn(True, False, True)))
 
+  def test_unsupported_ops(self):
+    def test_fn(a, b):
+      return a in b
+
+    with self.converted(test_fn, logical_expressions, {}) as result:
+      self.assertTrue(result.test_fn('a', ('a',)))
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/name_scopes.py
similarity index 95%
rename from tensorflow/contrib/autograph/converters/name_scopes.py
rename to tensorflow/python/autograph/converters/name_scopes.py
index dd6c6bf..a9c55cc 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/name_scopes.py
@@ -20,8 +20,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
 
 
 class FunctionNameScopeTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/name_scopes_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/converters/name_scopes_test.py
rename to tensorflow/python/autograph/converters/name_scopes_test.py
index a329b0d..73933c1 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/name_scopes_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
similarity index 97%
rename from tensorflow/contrib/autograph/converters/return_statements.py
rename to tensorflow/python/autograph/converters/return_statements.py
index a351cd8..62da045 100644
--- a/tensorflow/contrib/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -20,11 +20,11 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 # TODO(mdan): Move this logic into transformer_base.
diff --git a/tensorflow/contrib/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/converters/return_statements_test.py
rename to tensorflow/python/autograph/converters/return_statements_test.py
index 3c7c8c8..01dd03d 100644
--- a/tensorflow/contrib/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
similarity index 94%
rename from tensorflow/contrib/autograph/converters/side_effect_guards.py
rename to tensorflow/python/autograph/converters/side_effect_guards.py
index b808604..6e48e57 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -36,12 +36,12 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 
 class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
similarity index 97%
rename from tensorflow/contrib/autograph/converters/side_effect_guards_test.py
rename to tensorflow/python/autograph/converters/side_effect_guards_test.py
index 5fe5114..cef3199 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/python/autograph/converters/slices.py
similarity index 93%
rename from tensorflow/contrib/autograph/converters/slices.py
rename to tensorflow/python/autograph/converters/slices.py
index c527f98..11cea6d 100644
--- a/tensorflow/contrib/autograph/converters/slices.py
+++ b/tensorflow/python/autograph/converters/slices.py
@@ -20,9 +20,9 @@
 
 import gast
 
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import templates
 
 
 class SliceTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py
similarity index 87%
rename from tensorflow/contrib/autograph/converters/slices_test.py
rename to tensorflow/python/autograph/converters/slices_test.py
index d74b2e0..e190a7c 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/python/autograph/converters/slices_test.py
@@ -18,12 +18,12 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import list_ops
diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
similarity index 78%
rename from tensorflow/contrib/autograph/core/BUILD
rename to tensorflow/python/autograph/core/BUILD
index 1873045..85fecf0 100644
--- a/tensorflow/contrib/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -25,9 +25,9 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/pyct/static_analysis",
-        "//tensorflow/contrib/autograph/utils",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/pyct/static_analysis",
+        "//tensorflow/python/autograph/utils",
     ],
 )
 
@@ -65,10 +65,10 @@
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         ":core",
-        "//tensorflow/contrib/autograph/operators",
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/pyct/static_analysis",
-        "//tensorflow/contrib/autograph/utils",
+        "//tensorflow/python/autograph/operators",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/pyct/static_analysis",
+        "//tensorflow/python/autograph/utils",
         "@gast_archive//:gast",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/contrib/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
similarity index 92%
rename from tensorflow/contrib/autograph/core/config.py
rename to tensorflow/python/autograph/core/config.py
index 878bb7e..4fa8489 100644
--- a/tensorflow/contrib/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph import utils
+from tensorflow.python.autograph import utils
 
 
 PYTHON_LITERALS = {
@@ -36,7 +36,7 @@
     # have well-known names. Not referring to the module directly to avoid
     # circular imports.
     (
-        utils.__name__[:-len('.contrib.autograph.utils')],),
+        utils.__name__[:-len('.python.autograph.utils')],),
 ))
 
 NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
similarity index 92%
rename from tensorflow/contrib/autograph/core/converter.py
rename to tensorflow/python/autograph/core/converter.py
index 83a80c1..7b3905f 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -67,19 +67,19 @@
 from enum import Enum
 
 
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import naming
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import naming
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
 
 # TODO(mdan): These contexts can be refactored into first class objects.
 # For example, we could define Program and Entity abstractions that hold on
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
similarity index 90%
rename from tensorflow/contrib/autograph/core/converter_testing.py
rename to tensorflow/python/autograph/core/converter_testing.py
index 5ee2c3f..0a0c6f9 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -24,15 +24,15 @@
 
 import six
 
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import pretty_printer
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
similarity index 99%
rename from tensorflow/contrib/autograph/core/errors.py
rename to tensorflow/python/autograph/core/errors.py
index 5a57d57..0750353 100644
--- a/tensorflow/contrib/autograph/core/errors.py
+++ b/tensorflow/python/autograph/core/errors.py
@@ -31,7 +31,7 @@
 import sys
 import traceback
 
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.framework import errors_impl
 
 # TODO(mdan): Add a superclass common to all errors.
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/core/errors_test.py
rename to tensorflow/python/autograph/core/errors_test.py
index 404c1f5..0444ed7 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import origin_info
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors as tf_errors
 from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py
similarity index 98%
rename from tensorflow/contrib/autograph/core/naming.py
rename to tensorflow/python/autograph/core/naming.py
index b1d3f76..aecc9e3 100644
--- a/tensorflow/contrib/autograph/core/naming.py
+++ b/tensorflow/python/autograph/core/naming.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import qual_names
 
 
 class Namer(object):
diff --git a/tensorflow/contrib/autograph/core/naming_test.py b/tensorflow/python/autograph/core/naming_test.py
similarity index 97%
rename from tensorflow/contrib/autograph/core/naming_test.py
rename to tensorflow/python/autograph/core/naming_test.py
index d2bebd0..2db9883 100644
--- a/tensorflow/contrib/autograph/core/naming_test.py
+++ b/tensorflow/python/autograph/core/naming_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.core import naming
+from tensorflow.python.autograph.core import naming
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
similarity index 90%
rename from tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
rename to tensorflow/python/autograph/docs/pyfunc_dtypes.md
index bcbb920..c2427f5 100644
--- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
+++ b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
@@ -4,7 +4,7 @@
 [data type](https://www.tensorflow.org/guide/tensors#data_types).
 
 When wrapping a function with `py_func`, for instance using
-`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two
+`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
 options to specify the returned data type:
 
  * explicitly, with a specified `tf.DType` value
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD
similarity index 75%
rename from tensorflow/contrib/autograph/impl/BUILD
rename to tensorflow/python/autograph/impl/BUILD
index a543859..bef62a6 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/python/autograph/impl/BUILD
@@ -23,14 +23,14 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/converters",
-        "//tensorflow/contrib/autograph/core",
-        "//tensorflow/contrib/autograph/operators",
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/pyct/static_analysis",
-        "//tensorflow/contrib/autograph/utils",
         "//tensorflow/python:platform",
         "//tensorflow/python:util",
+        "//tensorflow/python/autograph/converters",
+        "//tensorflow/python/autograph/core",
+        "//tensorflow/python/autograph/operators",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/pyct/static_analysis",
+        "//tensorflow/python/autograph/utils",
         "@gast_archive//:gast",
         "@six_archive//:six",
     ],
@@ -43,8 +43,8 @@
     tags = ["no_windows"],
     deps = [
         ":impl",
-        "//tensorflow/contrib/autograph/utils",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/utils",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
similarity index 94%
rename from tensorflow/contrib/autograph/impl/api.py
rename to tensorflow/python/autograph/impl/api.py
index 276a387..669d36b 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -22,17 +22,13 @@
 
 from enum import Enum
 
-# pylint:disable=g-bad-import-order
-import six
-# pylint:enable=g-bad-import-order
-
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import conversion
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import conversion
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -150,7 +146,7 @@
   unknown_arg_value = object()  # Sentinel for arguments of unknown value
 
   if inspect_utils.isbuiltin(f):
-    return builtins.dynamic_builtin(f, *args, **kwargs)
+    return py_builtins.overload_of(f)(*args, **kwargs)
 
   if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
     # Regular functions
@@ -257,7 +253,7 @@
                                                   arg_types)
 
   nodes = []
-  for dep in reversed(program_ctx.dependency_cache.values()):
+  for dep in reversed(tuple(program_ctx.dependency_cache.values())):
     nodes.extend(dep)
   compiled_module, compiled_src = compiler.ast_to_object(
       nodes,
@@ -327,6 +323,6 @@
 
   code = '\n'.join(
       compiler.ast_to_source(dep, indentation)
-      for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
+      for dep in reversed(tuple(program_ctx.dependency_cache.values())))
 
   return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/impl/api_test.py
rename to tensorflow/python/autograph/impl/api_test.py
index 803fde9..54e12f0 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -20,11 +20,11 @@
 
 import numpy as np
 
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
 from tensorflow.python.util import tf_inspect
@@ -38,9 +38,6 @@
   def setUp(self):
     config.COMPILED_IMPORT_STATEMENTS = (
         'from __future__ import print_function',
-        'from tensorflow.contrib.autograph import utils'
-        ' as autograph_utils',
-        'tf = autograph_utils.fake_tf()',
     )
 
   def test_decorator_recurses(self):
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
similarity index 87%
rename from tensorflow/contrib/autograph/impl/conversion.py
rename to tensorflow/python/autograph/impl/conversion.py
index fc8a976..928ff9e 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -22,34 +22,34 @@
 
 import gast
 
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import directives
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.converters import directives
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/impl/conversion_test.py
rename to tensorflow/python/autograph/impl/conversion_test.py
index 8643257..07d0f75 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -20,11 +20,11 @@
 
 import gast
 
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.impl import conversion
 from tensorflow.python.framework import constant_op
 from tensorflow.python.keras.engine import training
 from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD
similarity index 93%
rename from tensorflow/contrib/autograph/lang/BUILD
rename to tensorflow/python/autograph/lang/BUILD
index 77a2184..462349c 100644
--- a/tensorflow/contrib/autograph/lang/BUILD
+++ b/tensorflow/python/autograph/lang/BUILD
@@ -25,7 +25,7 @@
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/operators",
+        "//tensorflow/python/autograph/operators",
     ],
 )
 
diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/python/autograph/lang/directives.py
similarity index 100%
rename from tensorflow/contrib/autograph/lang/directives.py
rename to tensorflow/python/autograph/lang/directives.py
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
similarity index 97%
rename from tensorflow/contrib/autograph/lang/special_functions.py
rename to tensorflow/python/autograph/lang/special_functions.py
index 6149cbb..e4838d1 100644
--- a/tensorflow/contrib/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -23,7 +23,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
 
 
 def tensor_list(elements,
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
similarity index 97%
rename from tensorflow/contrib/autograph/lang/special_functions_test.py
rename to tensorflow/python/autograph/lang/special_functions_test.py
index db492cc..1f1cec1 100644
--- a/tensorflow/contrib/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.python.autograph.lang import special_functions
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_util
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
similarity index 84%
rename from tensorflow/contrib/autograph/operators/BUILD
rename to tensorflow/python/autograph/operators/BUILD
index 332d5da..a116611 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -22,12 +22,12 @@
         "__init__.py",
         "control_flow.py",
         "data_structures.py",
+        "py_builtins.py",
         "slices.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:control_flow_ops",
@@ -37,6 +37,7 @@
         "//tensorflow/python:tensor_array_ops",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:variables",
+        "//tensorflow/python/autograph/utils",
         "//tensorflow/python/data/ops:dataset_ops",
     ],
 )
@@ -62,6 +63,17 @@
 )
 
 py_test(
+    name = "py_builtins_test",
+    srcs = ["py_builtins_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_windows"],
+    deps = [
+        ":operators",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
     name = "slices_test",
     srcs = ["slices_test.py"],
     srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
similarity index 60%
rename from tensorflow/contrib/autograph/operators/__init__.py
rename to tensorflow/python/autograph/operators/__init__.py
index 392cb60..0d3b44b 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -37,14 +37,19 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.operators.control_flow import for_stmt
-from tensorflow.contrib.autograph.operators.control_flow import while_stmt
-from tensorflow.contrib.autograph.operators.data_structures import list_append
-from tensorflow.contrib.autograph.operators.data_structures import list_pop
-from tensorflow.contrib.autograph.operators.data_structures import list_stack
-from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
-from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
-from tensorflow.contrib.autograph.operators.data_structures import new_list
-from tensorflow.contrib.autograph.operators.slices import get_item
-from tensorflow.contrib.autograph.operators.slices import GetItemOpts
-from tensorflow.contrib.autograph.operators.slices import set_item
+from tensorflow.python.autograph.operators.control_flow import for_stmt
+from tensorflow.python.autograph.operators.control_flow import while_stmt
+from tensorflow.python.autograph.operators.data_structures import list_append
+from tensorflow.python.autograph.operators.data_structures import list_pop
+from tensorflow.python.autograph.operators.data_structures import list_stack
+from tensorflow.python.autograph.operators.data_structures import ListPopOpts
+from tensorflow.python.autograph.operators.data_structures import ListStackOpts
+from tensorflow.python.autograph.operators.data_structures import new_list
+from tensorflow.python.autograph.operators.py_builtins import float_
+from tensorflow.python.autograph.operators.py_builtins import int_
+from tensorflow.python.autograph.operators.py_builtins import len_
+from tensorflow.python.autograph.operators.py_builtins import print_
+from tensorflow.python.autograph.operators.py_builtins import range_
+from tensorflow.python.autograph.operators.slices import get_item
+from tensorflow.python.autograph.operators.slices import GetItemOpts
+from tensorflow.python.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
similarity index 97%
rename from tensorflow/contrib/autograph/operators/control_flow.py
rename to tensorflow/python/autograph/operators/control_flow.py
index 9909e52..6eedd69 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.python.autograph.operators import py_builtins
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -82,8 +82,8 @@
 
 
 def _known_len_for_stmt(iter_, extra_test, body, init_state):
-  """Overload of for_stmt that iterates over objects that define a length."""
-  n = builtins.dynamic_len(iter_)
+  """Overload of for_stmt that iterates over objects that admit a length."""
+  n = py_builtins.len_(iter_)
 
   def while_body(iterate_index, *state):
     iterate = iter_[iterate_index]
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
similarity index 97%
rename from tensorflow/contrib/autograph/operators/control_flow_test.py
rename to tensorflow/python/autograph/operators/control_flow_test.py
index 677b7f8..bb214b6 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.operators import control_flow
+from tensorflow.python.autograph.operators import control_flow
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
similarity index 100%
rename from tensorflow/contrib/autograph/operators/data_structures.py
rename to tensorflow/python/autograph/operators/data_structures.py
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
similarity index 98%
rename from tensorflow/contrib/autograph/operators/data_structures_test.py
rename to tensorflow/python/autograph/operators/data_structures_test.py
index 4b1e835..8532dbe 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/operators/dispatch_context.py b/tensorflow/python/autograph/operators/dispatch_context.py
similarity index 100%
rename from tensorflow/contrib/autograph/operators/dispatch_context.py
rename to tensorflow/python/autograph/operators/dispatch_context.py
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
new file mode 100644
index 0000000..1d37ae7
--- /dev/null
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+UNDEFINED = object()
+
+
+def overload_of(f):
+  if f in SUPPORTED_BUILTINS:
+    return BUILTIN_FUINCTIONS_MAP[f.__name__]
+  return f
+
+
+def abs_(x):
+  if tensor_util.is_tensor(x):
+    return _tf_abs(x)
+  return _py_abs(x)
+
+
+def _tf_abs(x):
+  return math_ops.abs(x)
+
+
+def _py_abs(x):
+  return abs(x)
+
+
+def float_(x=0):
+  if tensor_util.is_tensor(x):
+    return _tf_float(x)
+  return _py_float(x)
+
+
+def _tf_float(x):
+  # TODO(mdan): We shouldn't assume float32.
+  if x.dtype == dtypes.string:
+    return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+  return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+  return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+  if tensor_util.is_tensor(x):
+    return _tf_int(x, base)
+  return _py_int(x, base)
+
+
+def _tf_int(x, base):
+  if base not in (10, UNDEFINED):
+    raise NotImplementedError('base {} not supported for int'.format(base))
+
+  # TODO(mdan): We shouldn't assume int32.
+  if x.dtype == dtypes.string:
+    return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+  return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+  if base is UNDEFINED:
+    return int(x)
+  return int(x, base)
+
+
+def len_(s):
+  if tensors.is_tensor_array(s):
+    return _tf_tensor_array_len(s)
+  elif tensors.is_tensor_list(s):
+    return _tf_tensor_list_len(s)
+  elif tensor_util.is_tensor(s):
+    return _tf_tensor_len(s)
+  return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+  return s.size()
+
+
+def _tf_tensor_list_len(s):
+  return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+  """Overload of len_ for Tensor arguments."""
+  # Statically shaped tensors: length is known ahead of time.
+  if s.shape.ndims and s.shape[0].value is not None:
+    return s.shape[0].value
+
+  # Static shape of unknown dimensions: use dynamic shape, but statically
+  # check that it's not a scalar.
+  shape = array_ops.shape(s)
+
+  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+  if shape.shape[0] == 0:
+    raise ValueError(
+        'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+  if shape.shape[0].value is not None:
+    return array_ops.shape(s)[0]
+
+  # Fully dynamic shape: use ops.
+  rank = array_ops.rank(s)
+
+  def raise_zero_rank_error():
+    msg = gen_string_ops.string_join(
+        ['len requires non-zero rank, got ',
+         gen_string_ops.as_string(rank)])
+    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+      return constant_op.constant(0, dtype=dtypes.int32)
+
+  return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+                               raise_zero_rank_error)
+
+
+def _py_len(s):
+  return len(s)
+
+
+def print_(*objects, **kwargs):
+  # Note: Python 2.6 doesn't support explicit keywords after starargs.
+  unknown_kwargs = tuple(
+      set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+  if unknown_kwargs:
+    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+  # TODO(mdan): use logging_ops.Print when py_func is not supported.
+  return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+  """Overload of print_ as a py_func implementation."""
+  override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+  if 'flush' not in override_kwargs:
+    # Defaulting to flushing the console in graph mode, which helps reduce
+    # garbled output in IPython.
+    override_kwargs['flush'] = True
+
+  def print_wrapper(*vals):
+    if six.PY3:
+      # TensorFlow doesn't seem to generate Unicode when passing strings to
+      # py_func. This causes the print to add a "b'" wrapper to the output,
+      # which is probably never what you want.
+      vals = tuple(
+          v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+    six.print_(*vals, **override_kwargs)
+
+  return py_func.wrap_py_func(
+      print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+  if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+    return _tf_range(start_or_stop, stop, step)
+  return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+  # TODO(mdan): We should optimize this when a full tensor is not required.
+  if step is not UNDEFINED:
+    return math_ops.range(start_or_stop, stop, step)
+  if stop is not UNDEFINED:
+    return math_ops.range(start_or_stop, stop)
+  return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+  if step is not UNDEFINED:
+    return range(start_or_stop, stop, step)
+  if stop is not UNDEFINED:
+    return range(start_or_stop, stop)
+  return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+  SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUINCTIONS_MAP = {
+    'abs': abs_,
+    'float': float_,
+    'int': int_,
+    'len': len_,
+    'print': print_,
+    'range': range_,
+    'xrange': range_,
+}
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000..a021263
--- /dev/null
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+
+  def test_abs(self):
+    self.assertEqual(py_builtins.abs_(-1), 1)
+    with self.test_session() as sess:
+      t = py_builtins.abs_(constant_op.constant(-1))
+      self.assertEqual(sess.run(t), 1)
+      t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+      self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+  def test_float(self):
+    self.assertEqual(py_builtins.float_(10), 10.0)
+    self.assertEqual(py_builtins.float_('10.0'), 10.0)
+    with self.test_session() as sess:
+      t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+      self.assertEqual(sess.run(t), 1.0)
+      st = py_builtins.float_(constant_op.constant('1.0'))
+      self.assertEqual(sess.run(st), 1.0)
+
+  def test_int(self):
+    self.assertEqual(py_builtins.int_(10.0), 10)
+    self.assertEqual(py_builtins.int_('11', 2), 3)
+    with self.test_session() as sess:
+      t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+      self.assertEqual(sess.run(t), 1)
+      st = py_builtins.int_(constant_op.constant('1'))
+      self.assertEqual(sess.run(st), 1)
+      st = py_builtins.int_(constant_op.constant('1'), 10)
+      self.assertEqual(sess.run(st), 1)
+
+  def test_int_unsupported_base(self):
+    t = constant_op.constant(1, dtype=dtypes.float64)
+    with self.assertRaises(NotImplementedError):
+      py_builtins.int_(t, 2)
+
+  def test_len(self):
+    self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+    with self.test_session() as sess:
+      t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+      self.assertEqual(t, 3)
+      ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+      self.assertEqual(sess.run(ta), 5)
+      tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+      self.assertEqual(sess.run(tl), 3)
+
+  def test_len_scalar(self):
+    with self.assertRaises(ValueError):
+      py_builtins.len_(constant_op.constant(1))
+
+  def test_len_dynamic_shape(self):
+    with self.test_session() as sess:
+      p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+      t = py_builtins.len_(p)
+      self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        t = py_builtins.len_(p)
+        sess.run(t, {p: 1})
+
+  def test_print_tensors(self):
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+        self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+  def test_print_complex(self):
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(
+            py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+        self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+  def test_range(self):
+    self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+    self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+    self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+  def test_range_tensor(self):
+    with self.test_session() as sess:
+      r = py_builtins.range_(constant_op.constant(3))
+      self.assertAllEqual(sess.run(r), [0, 1, 2])
+      r = py_builtins.range_(1, constant_op.constant(3))
+      self.assertAllEqual(sess.run(r), [1, 2])
+      r = py_builtins.range_(2, 0, constant_op.constant(-1))
+      self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/python/autograph/operators/slices.py
similarity index 92%
rename from tensorflow/contrib/autograph/operators/slices.py
rename to tensorflow/python/autograph/operators/slices.py
index 04fbeb2..2b7f5ad 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/python/autograph/operators/slices.py
@@ -22,6 +22,7 @@
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_string_ops
 from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import tensor_array_ops
 
@@ -57,6 +58,8 @@
   elif tensor_util.is_tensor(target):
     if target.dtype == dtypes.variant:
       return _tf_tensor_list_get_item(target, i, opts)
+    elif target.dtype == dtypes.string and target.shape.ndims == 0:
+      return _tf_tensor_string_get_item(target, i)
     else:
       return _tf_tensor_get_item(target, i)
   else:
@@ -82,6 +85,12 @@
   return target[i]
 
 
+def _tf_tensor_string_get_item(target, i):
+  """Overload of get_item that stages a Tensor string read."""
+  x = gen_string_ops.substr(target, i, 1)
+  return x
+
+
 def _py_get_item(target, i):
   """Overload of get_item that executes a Python list modification."""
   return target[i]
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
similarity index 75%
rename from tensorflow/contrib/autograph/operators/slices_test.py
rename to tensorflow/python/autograph/operators/slices_test.py
index 56aafe0..d8b8418 100644
--- a/tensorflow/contrib/autograph/operators/slices_test.py
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.autograph.operators import slices
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import list_ops
 from tensorflow.python.platform import test
@@ -46,6 +46,21 @@
     with self.cached_session() as sess:
       self.assertAllEqual(sess.run(t), [3, 4])
 
+  def test_get_item_tensor_string(self):
+    initial_str = constant_op.constant('abcd')
+    t = slices.get_item(initial_str, 1,
+                        slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(t), b'b')
+
+    initial_list_str = constant_op.constant(['abcd', 'bcde'])
+    t = slices.get_item(initial_list_str, 1,
+                        slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(t), b'bcde')
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/BUILD
rename to tensorflow/python/autograph/pyct/BUILD
diff --git a/tensorflow/contrib/autograph/pyct/__init__.py b/tensorflow/python/autograph/pyct/__init__.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/__init__.py
rename to tensorflow/python/autograph/pyct/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/anno.py
rename to tensorflow/python/autograph/pyct/anno.py
diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/python/autograph/pyct/anno_test.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/anno_test.py
rename to tensorflow/python/autograph/pyct/anno_test.py
index 5ef4da6..1f87387 100644
--- a/tensorflow/contrib/autograph/pyct/anno_test.py
+++ b/tensorflow/python/autograph/pyct/anno_test.py
@@ -20,7 +20,7 @@
 
 import ast
 
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/ast_util.py
rename to tensorflow/python/autograph/pyct/ast_util.py
index d7453b0..7df3b88 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -22,8 +22,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
 
 
 class CleanCopier(object):
diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/pyct/ast_util_test.py
rename to tensorflow/python/autograph/pyct/ast_util_test.py
index 2293c89..b1577c4 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util_test.py
+++ b/tensorflow/python/autograph/pyct/ast_util_test.py
@@ -22,11 +22,11 @@
 import collections
 import textwrap
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
similarity index 99%
rename from tensorflow/contrib/autograph/pyct/cfg.py
rename to tensorflow/python/autograph/pyct/cfg.py
index ba51dcf..1433f9a 100644
--- a/tensorflow/contrib/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -33,7 +33,7 @@
 import gast
 # pylint:enable=g-bad-import-order
 
-from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import compiler
 
 
 class Node(object):
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
similarity index 99%
rename from tensorflow/contrib/autograph/pyct/cfg_test.py
rename to tensorflow/python/autograph/pyct/cfg_test.py
index 9d0a85d..bd82e70 100644
--- a/tensorflow/contrib/autograph/pyct/cfg_test.py
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -18,8 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD
similarity index 94%
rename from tensorflow/contrib/autograph/pyct/common_transformers/BUILD
rename to tensorflow/python/autograph/pyct/common_transformers/BUILD
index fe630ef..5e2f8f3 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD
@@ -26,7 +26,7 @@
         "@six_archive//:six",
         # TODO(aqj) Revisit this dependency direction when pyct is more
         # modularized
-        "//tensorflow/contrib/autograph/pyct",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
rename to tensorflow/python/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/common_transformers/anf.py
rename to tensorflow/python/autograph/pyct/common_transformers/anf.py
index e42f679..192621b 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py
@@ -29,8 +29,8 @@
 import gast
 import six
 
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
 
 
 class DummyGensym(object):
@@ -394,10 +394,16 @@
   # just recur.
 
   def visit_List(self, node):
-    return self._visit_strict_expression(node)
+    node = self.generic_visit(node)
+    if not isinstance(node.ctx, gast.Store):
+      self._ensure_fields_trivial(node)
+    return node
 
   def visit_Tuple(self, node):
-    return self._visit_strict_expression(node)
+    node = self.generic_visit(node)
+    if not isinstance(node.ctx, gast.Store):
+      self._ensure_fields_trivial(node)
+    return node
 
 
 def transform(node, entity_info, gensym_source=None):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
similarity index 89%
rename from tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
rename to tensorflow/python/autograph/pyct/common_transformers/anf_test.py
index 9519748..ccc7e4c 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -20,10 +20,10 @@
 
 import textwrap
 
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.common_transformers import anf
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.common_transformers import anf
 from tensorflow.python.platform import test
 
 
@@ -165,6 +165,46 @@
 
     self.assert_body_anfs_as_expected(expected_result, test_function)
 
+  def test_nested_multi_value_assign(self):
+
+    def test_function(a, b, c):
+      x, y = a, a + b
+      (z, y), x = (c, y + b), x + a
+      return z, (y, x)
+
+    def expected_result(a, b, c):
+      tmp_1001 = a + b
+      x, y = a, tmp_1001
+      tmp_1002 = y + b
+      tmp_1003 = (c, tmp_1002)
+      tmp_1004 = x + a
+      (z, y), x = tmp_1003, tmp_1004
+      tmp_1005 = y, x
+      tmp_1006 = z, tmp_1005
+      return tmp_1006
+
+    self.assert_body_anfs_as_expected(expected_result, test_function)
+
+  def test_deeply_nested_multi_value_assign(self):
+
+    def test_function(a):
+      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+      return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]
+
+    def expected_result(a):
+      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+      tmp_1001 = b, c
+      tmp_1002 = [d, e]
+      tmp_1003 = [tmp_1001, tmp_1002]
+      tmp_1004 = f, g
+      tmp_1005 = h, i, j
+      tmp_1006 = tmp_1003, tmp_1004
+      tmp_1007 = [tmp_1005, k]
+      tmp_1008 = [tmp_1006, tmp_1007]
+      return tmp_1008
+
+    self.assert_body_anfs_as_expected(expected_result, test_function)
+
   def test_local_definition_and_binary_compare(self):
 
     def test_function():
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/compiler.py
rename to tensorflow/python/autograph/pyct/compiler.py
index f9cee10..9e1b6bd 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -30,7 +30,7 @@
 import astor
 import gast
 
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
 
 
 def ast_to_source(node, indentation='  '):
diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/compiler_test.py
rename to tensorflow/python/autograph/pyct/compiler_test.py
index cf783da..6fa289d 100644
--- a/tensorflow/contrib/autograph/pyct/compiler_test.py
+++ b/tensorflow/python/autograph/pyct/compiler_test.py
@@ -22,8 +22,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 from tensorflow.python.util import tf_inspect
 
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/inspect_utils.py
rename to tensorflow/python/autograph/pyct/inspect_utils.py
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/inspect_utils_test.py
rename to tensorflow/python/autograph/pyct/inspect_utils_test.py
index 1a212f6..f3eb027 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -22,7 +22,7 @@
 
 import six
 
-from tensorflow.contrib.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import inspect_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/origin_info.py
rename to tensorflow/python/autograph/pyct/origin_info.py
index b60651a..4c7c416 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -23,9 +23,9 @@
 import gast
 import six
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
similarity index 93%
rename from tensorflow/contrib/autograph/pyct/origin_info_test.py
rename to tensorflow/python/autograph/pyct/origin_info_test.py
index eeaa130..6b9c30d 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -18,10 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/parser.py
rename to tensorflow/python/autograph/pyct/parser.py
diff --git a/tensorflow/contrib/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/parser_test.py
rename to tensorflow/python/autograph/pyct/parser_test.py
index 007a4c6..d0b465e 100644
--- a/tensorflow/contrib/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -20,7 +20,7 @@
 
 import textwrap
 
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import parser
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/pretty_printer.py
rename to tensorflow/python/autograph/pyct/pretty_printer.py
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py b/tensorflow/python/autograph/pyct/pretty_printer_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/pretty_printer_test.py
rename to tensorflow/python/autograph/pyct/pretty_printer_test.py
index 0cb48f3..1c76744 100644
--- a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer_test.py
@@ -20,7 +20,7 @@
 
 import ast
 
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import pretty_printer
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/qual_names.py
rename to tensorflow/python/autograph/pyct/qual_names.py
index fb81404..334cbd7 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/python/autograph/pyct/qual_names.py
@@ -29,8 +29,8 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
 
 
 class Symbol(collections.namedtuple('Symbol', ['name'])):
diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/qual_names_test.py
rename to tensorflow/python/autograph/pyct/qual_names_test.py
index c793c2bb..2da4dfd 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names_test.py
+++ b/tensorflow/python/autograph/pyct/qual_names_test.py
@@ -20,11 +20,11 @@
 
 import textwrap
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.qual_names import resolve
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.qual_names import resolve
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD
similarity index 82%
rename from tensorflow/contrib/autograph/pyct/static_analysis/BUILD
rename to tensorflow/python/autograph/pyct/static_analysis/BUILD
index 92eacba..4a4ccdc 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD
@@ -27,9 +27,9 @@
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/utils",
         "//tensorflow/python:util",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
         "@gast_archive//:gast",
     ],
 )
@@ -41,8 +41,8 @@
     tags = ["no_windows"],
     deps = [
         ":static_analysis",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
         "@gast_archive//:gast",
     ],
 )
@@ -54,8 +54,8 @@
     tags = ["no_windows"],
     deps = [
         ":static_analysis",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -65,8 +65,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":static_analysis",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -76,8 +76,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":static_analysis",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
     ],
 )
 
@@ -87,8 +87,8 @@
     srcs_version = "PY2AND3",
     deps = [
         ":static_analysis",
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/utils",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
     ],
 )
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
rename to tensorflow/python/autograph/pyct/static_analysis/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/static_analysis/activity.py
rename to tensorflow/python/autograph/pyct/static_analysis/activity.py
index a0182da..9cb5991 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -25,10 +25,10 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 
 # TODO(mdan): Add support for PY3 (e.g. Param vs arg).
 # TODO(alexbw): Ignore named literals (e.g. None)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
rename to tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index e940516..d4a6ce8 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -20,13 +20,13 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/python/autograph/pyct/static_analysis/annos.py
similarity index 100%
rename from tensorflow/contrib/autograph/pyct/static_analysis/annos.py
rename to tensorflow/python/autograph/pyct/static_analysis/annos.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
similarity index 90%
rename from tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
rename to tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 2d8f922..48b442f 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -25,9 +25,14 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+# TODO(aqj): Do we need this? Do other builtins fail in similar ways
+# See b/114389775 for a related bug in pyct
+# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
 
 
 class LiveValueResolver(transformer.Base):
@@ -66,6 +71,8 @@
             # If the symbol value is for example a primitive, then it will not
             # have a name.
             pass
+        elif node.id in _special_symbols:
+          anno.setanno(node, 'live_val', _special_symbols[node.id])
         else:
           pass
           # TODO(mdan): Should we raise an error here?
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
similarity index 87%
rename from tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
rename to tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
index fe30511..882c380 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
@@ -20,15 +20,15 @@
 
 import six
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
similarity index 96%
rename from tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
rename to tensorflow/python/autograph/pyct/static_analysis/liveness.py
index bf29d86..41c903b 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -26,10 +26,10 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
 
 
 class Analyzer(cfg.GraphVisitor):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
similarity index 89%
rename from tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
rename to tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index d53adb2..0d5f369 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -18,13 +18,13 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import liveness
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
rename to tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
index 7f2b379..9aaf318 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
@@ -30,10 +30,10 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
 
 
 class Definition(object):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
similarity index 94%
rename from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
rename to tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index 243fe80..373a2cb 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -18,13 +18,13 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
similarity index 97%
rename from tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
rename to tensorflow/python/autograph/pyct/static_analysis/type_info.py
index 835d519..edb2ef0 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
@@ -43,9 +43,9 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
similarity index 91%
rename from tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
rename to tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
index 404311b..34ba3d2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
@@ -18,15 +18,15 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
 from tensorflow.python.client import session
 from tensorflow.python.platform import test
 from tensorflow.python.training import training
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
similarity index 95%
rename from tensorflow/contrib/autograph/pyct/templates.py
rename to tensorflow/python/autograph/pyct/templates.py
index 5831d57..68c2a35 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -26,10 +26,10 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
 
 
 class ReplaceTransformer(gast.NodeTransformer):
@@ -113,7 +113,7 @@
     if isinstance(node, gast.Attribute):
       self._check_inner_children_have_context(node.value)
       self._check_has_context(node)
-    elif isinstance(node, gast.Tuple):
+    elif isinstance(node, (gast.Tuple, gast.List)):
       for e in node.elts:
         self._check_inner_children_have_context(e)
       self._check_has_context(node)
@@ -142,7 +142,7 @@
     if isinstance(node, gast.Attribute):
       self._set_inner_child_context(node.value, gast.Load())
       node.ctx = ctx
-    elif isinstance(node, gast.Tuple):
+    elif isinstance(node, (gast.Tuple, gast.List)):
       for e in node.elts:
         self._set_inner_child_context(e, ctx)
       node.ctx = ctx
@@ -191,7 +191,7 @@
 
     # Preserve the target context.
     for n in new_nodes:
-      if isinstance(n, gast.Tuple):
+      if isinstance(n, (gast.Tuple, gast.List)):
         for e in n.elts:
           self._set_inner_child_context(e, node.ctx)
       if isinstance(n, gast.Attribute):
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
similarity index 75%
rename from tensorflow/contrib/autograph/pyct/templates_test.py
rename to tensorflow/python/autograph/pyct/templates_test.py
index 77e8ff6..66268cf 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -22,9 +22,9 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
 from tensorflow.python.platform import test
 
 
@@ -110,6 +110,42 @@
     self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
     self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
 
+  def test_replace_list_context(self):
+    template = """
+      def test_fn(foo):
+        foo = 0
+    """
+
+    node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
+    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+  def test_replace_tuple_context(self):
+    template = """
+      def test_fn(foo):
+        foo = 0
+    """
+
+    node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
+    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+  def test_replace_complex_context(self):
+    template = """
+      def test_fn(foo):
+        foo = 0
+    """
+
+    node = templates.replace(
+        template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
+    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+    function_call_arg = node.body[0].targets[0].value.args[0]
+    self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
+    self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
+    self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+
   def test_replace_call_keyword(self):
     template = """
       def test_fn():
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD
similarity index 85%
rename from tensorflow/contrib/autograph/pyct/testing/BUILD
rename to tensorflow/python/autograph/pyct/testing/BUILD
index 29a9244..c244cbd 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/python/autograph/pyct/testing/BUILD
@@ -22,8 +22,8 @@
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        "//tensorflow/contrib/autograph/pyct",
-        "//tensorflow/contrib/autograph/utils",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
         "@gast_archive//:gast",
     ],
 )
@@ -41,8 +41,8 @@
     ],
     deps = [
         ":testing",
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
         "@gast_archive//:gast",
     ],
 )
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/python/autograph/pyct/testing/codegen.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/testing/codegen.py
rename to tensorflow/python/autograph/pyct/testing/codegen.py
index 279e7c0..78b2439 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen.py
@@ -24,7 +24,7 @@
 import gast
 import numpy as np
 
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import templates
 
 
 class NodeSampler(object):
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/python/autograph/pyct/testing/codegen_test.py
similarity index 91%
rename from tensorflow/contrib/autograph/pyct/testing/codegen_test.py
rename to tensorflow/python/autograph/pyct/testing/codegen_test.py
index 255c3b2..71665be 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen_test.py
@@ -20,8 +20,8 @@
 
 import numpy as np
 
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct.testing import codegen
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct.testing import codegen
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/transformer.py
rename to tensorflow/python/autograph/pyct/transformer.py
index 969ca12..520f503 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -23,9 +23,9 @@
 import gast
 import six
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import pretty_printer
 
 
 class AutographParseError(SyntaxError):
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py
similarity index 98%
rename from tensorflow/contrib/autograph/pyct/transformer_test.py
rename to tensorflow/python/autograph/pyct/transformer_test.py
index a37e922..23bf9a8 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/python/autograph/pyct/transformer_test.py
@@ -20,9 +20,9 @@
 
 import gast
 
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD
similarity index 93%
rename from tensorflow/contrib/autograph/utils/BUILD
rename to tensorflow/python/autograph/utils/BUILD
index d2b399f..22451d4 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/python/autograph/utils/BUILD
@@ -20,39 +20,28 @@
     name = "utils",
     srcs = [
         "__init__.py",
-        "builtins.py",
         "context_managers.py",
         "misc.py",
         "multiple_dispatch.py",
         "py_func.py",
         "tensor_list.py",
+        "tensors.py",
         "testing.py",
         "type_check.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
-        "//tensorflow/contrib/autograph/pyct",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:list_ops",
         "//tensorflow/python:script_ops",
+        "//tensorflow/python/autograph/pyct",
         "//tensorflow/python/data/ops:dataset_ops",
         "@six_archive//:six",
     ],
 )
 
 py_test(
-    name = "builtins_test",
-    srcs = ["builtins_test.py"],
-    srcs_version = "PY2AND3",
-    tags = ["no_windows"],
-    deps = [
-        ":utils",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
-py_test(
     name = "context_managers_test",
     srcs = ["context_managers_test.py"],
     srcs_version = "PY2AND3",
@@ -113,3 +102,13 @@
         "//tensorflow/python:list_ops",
     ],
 )
+
+py_test(
+    name = "tensors_test",
+    srcs = ["tensors_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":utils",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py
new file mode 100644
index 0000000..c781958
--- /dev/null
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility module that contains APIs usable in the generated code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
+from tensorflow.python.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
+from tensorflow.python.autograph.utils.py_func import wrap_py_func
+from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
+from tensorflow.python.autograph.utils.testing import fake_tf
+from tensorflow.python.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/autograph/utils/context_managers.py b/tensorflow/python/autograph/utils/context_managers.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/context_managers.py
rename to tensorflow/python/autograph/utils/context_managers.py
diff --git a/tensorflow/contrib/autograph/utils/context_managers_test.py b/tensorflow/python/autograph/utils/context_managers_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/utils/context_managers_test.py
rename to tensorflow/python/autograph/utils/context_managers_test.py
index 42e2772..7f0a15b 100644
--- a/tensorflow/contrib/autograph/utils/context_managers_test.py
+++ b/tensorflow/python/autograph/utils/context_managers_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils import context_managers
+from tensorflow.python.autograph.utils import context_managers
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import tensor_array_ops
diff --git a/tensorflow/contrib/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/misc.py
rename to tensorflow/python/autograph/utils/misc.py
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/python/autograph/utils/misc_test.py
similarity index 91%
rename from tensorflow/contrib/autograph/utils/misc_test.py
rename to tensorflow/python/autograph/utils/misc_test.py
index 71e358c..8d2b0d6 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/python/autograph/utils/misc_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.misc import alias_tensors
 from tensorflow.python.framework.constant_op import constant
 from tensorflow.python.ops.variables import Variable
 from tensorflow.python.platform import test
@@ -31,7 +31,7 @@
 
     new_a = alias_tensors(a)
     self.assertFalse(new_a is a)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(1, sess.run(new_a))
 
   def test_alias_tensors(self):
@@ -46,7 +46,7 @@
     self.assertTrue(new_v is v)
     self.assertTrue(new_s is s)
     self.assertTrue(new_l is l)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(1, sess.run(new_a))
 
 
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
similarity index 85%
rename from tensorflow/contrib/autograph/utils/multiple_dispatch.py
rename to tensorflow/python/autograph/utils/multiple_dispatch.py
index 70eef56..107c8f7 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -18,20 +18,10 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
+from tensorflow.python.autograph.utils.type_check import is_tensor
 from tensorflow.python.ops import control_flow_ops
 
 
-def dynamic_is(left, right):
-  # TODO(alexbw) if we're sure we should leave 'is' in place,
-  # then change the semantics in converters/logical_expressions.py
-  return left is right
-
-
-def dynamic_is_not(left, right):
-  return left is not right
-
-
 def run_cond(condition, true_fn, false_fn):
   """Type-dependent functional conditional.
 
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
new file mode 100644
index 0000000..2a77c89
--- /dev/null
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multiple_dispatch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import multiple_dispatch
+from tensorflow.python.client.session import Session
+from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.platform import test
+
+
+class MultipleDispatchTest(test.TestCase):
+
+  def test_run_cond_python(self):
+    true_fn = lambda: (2,)
+    false_fn = lambda: (3,)
+    self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2)
+    self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3)
+
+  def test_run_cond_tf(self):
+    true_fn = lambda: (constant(2),)
+    false_fn = lambda: (constant(3),)
+    with Session() as sess:
+      out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn)
+      self.assertEqual(sess.run(out), 2)
+      out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
+      self.assertEqual(sess.run(out), 3)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/python/autograph/utils/py_func.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/py_func.py
rename to tensorflow/python/autograph/utils/py_func.py
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/python/autograph/utils/py_func_test.py
similarity index 93%
rename from tensorflow/contrib/autograph/utils/py_func_test.py
rename to tensorflow/python/autograph/utils/py_func_test.py
index 2468263..1c220d9 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/python/autograph/utils/py_func_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.utils import py_func
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.platform import test
@@ -31,7 +31,7 @@
     def test_fn(a, b, c):
       return a + b + c
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                     (1, constant_op.constant(1), 1))
       self.assertEqual(3, sess.run(result))
@@ -52,7 +52,7 @@
     def test_fn(a, b):
       return a * b.foo
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
       self.assertEqual(35, sess.run(result))
       result = py_func.wrap_py_func(test_fn, dtypes.int64,
@@ -69,7 +69,7 @@
     def test_fn(a, b, c, d):
       return a * b.foo + c * d.foo
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
           'c': 11,
           'd': TestClass(13)
@@ -89,7 +89,7 @@
     def test_fn(_):
       side_counter[0] += 1
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
       self.assertEqual(1, sess.run(result))
       self.assertEqual([1], side_counter)
diff --git a/tensorflow/contrib/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/tensor_list.py
rename to tensorflow/python/autograph/utils/tensor_list.py
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
similarity index 93%
rename from tensorflow/contrib/autograph/utils/tensor_list_test.py
rename to tensorflow/python/autograph/utils/tensor_list_test.py
index d58489e..697c166 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -18,7 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.autograph.utils import tensor_list as tl
+from tensorflow.python.autograph.utils import tensor_list as tl
 from tensorflow.python.client.session import Session
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
@@ -42,18 +42,18 @@
     l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
     l = tl.dynamic_list_append(l, 1)
     s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(s), [1])
 
     l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
     l = tl.dynamic_list_append(l, 1)
     s = l.stack()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(s), [1])
 
     l = tl.TensorList(self._shape(()), dtypes.int32)
     l = tl.dynamic_list_append(l, 1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(l[0]), 1)
 
   def test_list_append_python(self):
@@ -107,7 +107,7 @@
     l0 = l[0]
     l[0] = b
     l1 = l[0]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       l0, l1, a, b = sess.run([l0, l1, a, b])
       self.assertEqual(l0, a)
       self.assertEqual(l1, b)
diff --git a/tensorflow/python/autograph/utils/tensors.py b/tensorflow/python/autograph/utils/tensors.py
new file mode 100644
index 0000000..fa5db81
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+The reason these utilities are not defined in TensorFlow is because they may
+not be fully robust, although they work in the vast majority of cases. So we
+define them here in order for their behavior to be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
+def is_tensor_array(t):
+  return isinstance(t, tensor_array_ops.TensorArray)
+
+
+def is_tensor_list(t):
+  # TODO(mdan): This is just a heuristic.
+  # With TF lacking support for templated types, this is unfortunately the
+  # closest we can get right now. A dedicated op ought to be possible to
+  # construct.
+  return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
+          not t.shape.ndims)
diff --git a/tensorflow/python/autograph/utils/tensors_test.py b/tensorflow/python/autograph/utils/tensors_test.py
new file mode 100644
index 0000000..1e7cfec
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class TensorsTest(test.TestCase):
+
+  def _simple_tensor_array(self):
+    return tensor_array_ops.TensorArray(dtypes.int32, size=3)
+
+  def _simple_tensor_list(self):
+    return list_ops.empty_tensor_list(
+        element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
+
+  def _simple_list_of_tensors(self):
+    return [constant_op.constant(1), constant_op.constant(2)]
+
+  def test_is_tensor_array(self):
+    self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
+    self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list()))
+    self.assertFalse(tensors.is_tensor_array(constant_op.constant(1)))
+    self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors()))
+    self.assertFalse(tensors.is_tensor_array(None))
+
+  def test_is_tensor_list(self):
+    self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array()))
+    self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
+    self.assertFalse(tensors.is_tensor_list(constant_op.constant(1)))
+    self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
+    self.assertFalse(tensors.is_tensor_list(None))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/testing.py
rename to tensorflow/python/autograph/utils/testing.py
diff --git a/tensorflow/contrib/autograph/utils/type_check.py b/tensorflow/python/autograph/utils/type_check.py
similarity index 100%
rename from tensorflow/contrib/autograph/utils/type_check.py
rename to tensorflow/python/autograph/utils/type_check.py
diff --git a/tensorflow/contrib/autograph/utils/type_check_test.py b/tensorflow/python/autograph/utils/type_check_test.py
similarity index 95%
rename from tensorflow/contrib/autograph/utils/type_check_test.py
rename to tensorflow/python/autograph/utils/type_check_test.py
index 3b67b71..b3d1304 100644
--- a/tensorflow/contrib/autograph/utils/type_check_test.py
+++ b/tensorflow/python/autograph/utils/type_check_test.py
@@ -20,7 +20,7 @@
 
 import numpy
 
-from tensorflow.contrib.autograph.utils import type_check
+from tensorflow.python.autograph.utils import type_check
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.platform import test
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 1841dd9..ae0ad27 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1132,7 +1132,7 @@
         for details of the allowable fetch types.
       feed_list: (Optional.) A list of `feed_dict` keys. See
         `tf.Session.run` for details of the allowable feed key types.
-      accept_options: (Optional.) Iff `True`, the returned `Callable` will be
+      accept_options: (Optional.) If `True`, the returned `Callable` will be
         able to accept `tf.RunOptions` and `tf.RunMetadata` as optional
         keyword arguments `options` and `run_metadata`, respectively, with
         the same syntax and semantics as `tf.Session.run`, which is useful
@@ -1302,9 +1302,7 @@
           node_def = op.node_def
         except KeyError:
           pass
-      if (self._config is not None and
-          self._config.experimental.client_handles_error_formatting):
-        message = error_interpolation.interpolate(message, self._graph)
+      message = error_interpolation.interpolate(message, self._graph)
       raise type(e)(node_def, op, message)
 
   def _extend_graph(self):
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 052be68..4afc639 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -49,6 +49,8 @@
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_control_flow_ops
+# Import gradients to resolve circular imports
+from tensorflow.python.ops import gradients  # pylint: disable=unused-import
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
@@ -1760,7 +1762,7 @@
     with self.assertRaises(ValueError):
       session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
                                                         feed_fn1, feed_fn2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       np1 = np.array([1.0, 1.5, 2.0, 2.5])
       np2 = np.array([3.0, 3.5, 4.0, 4.5])
       squared_tensor = SquaredTensor(np2)
@@ -1920,7 +1922,7 @@
       pass
 
   def testAutoConvertAndCheckData(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = array_ops.placeholder(dtype=dtypes.string)
       with self.assertRaisesRegexp(
           TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index c046e9c..03effde 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -161,7 +161,7 @@
     cpu_max = maximums[
         'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums[cpuname]
     # At least num1 + num2, both float32s (4 bytes each)
-    self.assertGreater(cpu_max.num_bytes, 8)
+    self.assertGreaterEqual(cpu_max.num_bytes, 8)
     self.assertGreater(cpu_max.timestamp, 0)
     self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
     self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 459f494..1a1ed04 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util.tf_export import tf_export
 
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 4)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 13)
 
 
 @tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 23c9824..631b87a 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -137,6 +137,8 @@
     size = "small",
     srcs = ["interleave_dataset_op_test.py"],
     additional_deps = [
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
@@ -154,6 +156,7 @@
     size = "small",
     srcs = ["map_dataset_op_test.py"],
     additional_deps = [
+        "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index 89de55d..c48708a 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -82,7 +82,7 @@
     self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
                      [t.shape.as_list() for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -111,7 +111,7 @@
     iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(get_next)
 
@@ -131,7 +131,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(2):
         actual = sess.run(get_next)
@@ -158,7 +158,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(2):
         actual = sess.run(get_next)
@@ -188,7 +188,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       actual = sess.run(get_next)
       expected = sparse_tensor.SparseTensorValue(
@@ -214,7 +214,7 @@
         .make_initializable_iterator())
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer)
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -262,7 +262,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_op,
           feed_dict={
@@ -307,7 +307,7 @@
             batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.DataLossError):
         sess.run(get_next)
 
@@ -318,7 +318,7 @@
             batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = sess.run(get_next)
       self.assertAllEqual([[], [], [], []], result)
       with self.assertRaises(errors.OutOfRangeError):
@@ -342,7 +342,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test with random sequence lengths, and max padding.
       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
       sess.run(
@@ -381,7 +381,7 @@
         (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
     padded_dataset = dataset.padded_batch(
         2, padded_shapes=([None], [None]), padding_values=('', 0))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       next_element = padded_dataset.make_one_shot_iterator().get_next()
       sess.run(next_element)
 
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index 4f7fd35..d5f5b2f 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -68,7 +68,7 @@
 
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # First run without caching to collect the "ground truth".
       sess.run(init_fifo_op)
       elements = []
@@ -132,7 +132,7 @@
     get_next1 = iterator1.get_next()
     get_next2 = iterator2.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
       sess.run(get_next1)  # this should succeed
@@ -162,7 +162,7 @@
     get_next1 = iterator1.get_next()
     get_next2 = iterator2.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
       elements = []
@@ -217,7 +217,7 @@
       uncached_iterator = uncached_dataset.make_initializable_iterator()
       uncached_next = uncached_iterator.get_next()
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
 
         sess.run(repeat_count.initializer)
         sess.run(cached_iterator.initializer)
@@ -261,7 +261,7 @@
 
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize with an empty upstream and a missing cache file (should
       # throw errors.OutOfRangeError immediately).
       sess.run(init_cache_op, feed_dict={count_placeholder: 0})
@@ -278,7 +278,7 @@
     i1 = d1.make_initializable_iterator()
     i2 = d2.make_initializable_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(i1.initializer)
 
       self.assertEqual(1, sess.run(i1.get_next()))
@@ -304,7 +304,7 @@
 
     expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i, expected in enumerate(expected_values):
         self.assertEqual(expected, sess.run(n),
                          "Unexpected value at index %s" % i)
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 159218c..5dfb84f 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -49,7 +49,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(9):
         result = sess.run(get_next)
@@ -83,7 +83,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(9):
         result = sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index ea5b41e..e43564a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -50,7 +50,7 @@
     self.assertEqual([c.shape for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       results = sess.run(get_next)
       for component, result_component in zip(components, results):
@@ -84,7 +84,7 @@
         [tensor_shape.TensorShape(c.dense_shape) for c in components],
         [shape for shape in iterator.output_shapes])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       results = sess.run(get_next)
       for component, result_component in zip(components, results):
@@ -115,7 +115,7 @@
         if sparse_tensor.is_sparse(c) else c.shape for c in components
     ], [shape for shape in iterator.output_shapes])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       results = sess.run(get_next)
       for component, result_component in zip(components, results):
@@ -142,7 +142,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(4):
         results = sess.run(get_next)
@@ -172,7 +172,7 @@
         [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
         [shape for shape in iterator.output_shapes])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       expected = [
           (sparse_tensor.SparseTensorValue(
@@ -232,7 +232,7 @@
         if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
     ], [shape for shape in iterator.output_shapes])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       expected = [
           (sparse_tensor.SparseTensorValue(
@@ -283,7 +283,7 @@
     self.assertEqual((), iterator.output_shapes["foo"])
     self.assertEqual((1,), iterator.output_shapes["bar"])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(3):
         results = sess.run(get_next)
@@ -300,7 +300,7 @@
     init_op = iterator.initializer
     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
 
       # Test with sparse tensor in the appropriate order.
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index fb55ae1..cd0c1dd 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -44,7 +44,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(2):  # Run twice to test reinitialization.
         sess.run(init_op)
         for _ in range(num_repeats):
@@ -61,7 +61,7 @@
         .make_one_shot_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(num_repeats):
         for elem in elem_sequence:
           self.assertAllEqual(elem, sess.run(get_next))
@@ -131,7 +131,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(num_inner_repeats * num_outer_repeats):
         for elem in input_list:
@@ -190,7 +190,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for elem in [0, 1]:
         for _ in range(num_parallel_iterators):
@@ -213,7 +213,7 @@
 
       self.assertEqual(dtype, get_next.dtype)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(init_op)
         for expected in [[1], [2], [3]]:
           next_val = sess.run(get_next)
@@ -234,7 +234,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for expected in [b"foo", b"bar", b"baz"]:
         next_val = sess.run(get_next)
@@ -255,7 +255,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual([1, 2, 3], sess.run(get_next))
       self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -278,7 +278,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual([1, 2, 3], sess.run(get_next))
       self.assertAllEqual([4, 5, 6], sess.run(get_next))
@@ -302,7 +302,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertEqual((1, 2), sess.run(get_next))
       self.assertEqual((3, 4), sess.run(get_next))
@@ -327,7 +327,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual(1, sess.run(get_next))
       self.assertAllEqual([2, 3], sess.run(get_next))
@@ -347,7 +347,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual(0, sess.run(get_next))
       self.assertAllEqual(1, sess.run(get_next))
@@ -405,7 +405,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
       for x in expected:
@@ -434,7 +434,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       expected = [(0, b"Hi!"),
                   (0, b"Hi!"), (1, b"Hi!"),
@@ -468,7 +468,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual(37, sess.run(get_next))
       self.assertAllEqual(37, sess.run(get_next))
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 2c4c11e..239aa85 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -27,7 +27,7 @@
 
   def testAsSerializedGraph(self):
     dataset = dataset_ops.Dataset.range(10)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       graph = graph_pb2.GraphDef().FromString(
           sess.run(dataset._as_serialized_graph()))
       self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 4f2216f..19944d3 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -59,7 +59,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test that we can dynamically feed a different modulus value for each
       # iterator.
       def do_test(count_val, modulus_val):
@@ -84,7 +84,7 @@
     iterator = dataset.make_one_shot_iterator()
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(0, sess.run(get_next))
       self.assertEqual(1, sess.run(get_next))
       self.assertEqual(3, sess.run(get_next))
@@ -98,7 +98,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         if (i ** 2) % 2 == 0:
@@ -123,7 +123,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual(input_data[0], sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
@@ -151,7 +151,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(5):
         actual = sess.run(get_next)
@@ -169,7 +169,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         self.assertEqual((i, True), sess.run(get_next))
@@ -181,7 +181,7 @@
         lambda x: math_ops.equal(x % 2, 0))
     iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
     next_elements = [iterator.get_next() for iterator in iterators]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
 
 
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 350234a..1123cbf 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -43,7 +43,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in repeats:
         for _ in range(i):
@@ -62,7 +62,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for row in repeats:
         for i in row:
@@ -113,7 +113,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         for _ in range(i ** 2):
@@ -137,7 +137,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index 7dbf726..a35cee5 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -19,8 +19,10 @@
 
 import itertools
 
+from absl.testing import parameterized
+import numpy as np
+
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
@@ -28,7 +30,7 @@
 from tensorflow.python.platform import test
 
 
-class InterleaveDatasetTest(test.TestCase):
+class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
 
   def _interleave(self, lists, cycle_length, block_length):
     num_open = 0
@@ -97,84 +99,85 @@
         expected_elements, self._interleave(input_lists, 7, 2)):
       self.assertEqual(expected, produced)
 
-  def testInterleaveDataset(self):
-    input_values = array_ops.placeholder(dtypes.int64, shape=[None])
-    cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
-    block_length = array_ops.placeholder(dtypes.int64, shape=[])
+  @parameterized.named_parameters(
+      ("1", np.int64([4, 5, 6]), 1, 3, None),
+      ("2", np.int64([4, 5, 6]), 1, 3, 1),
+      ("3", np.int64([4, 5, 6]), 2, 1, None),
+      ("4", np.int64([4, 5, 6]), 2, 1, 1),
+      ("5", np.int64([4, 5, 6]), 2, 1, 2),
+      ("6", np.int64([4, 5, 6]), 2, 3, None),
+      ("7", np.int64([4, 5, 6]), 2, 3, 1),
+      ("8", np.int64([4, 5, 6]), 2, 3, 2),
+      ("9", np.int64([4, 5, 6]), 7, 2, None),
+      ("10", np.int64([4, 5, 6]), 7, 2, 1),
+      ("11", np.int64([4, 5, 6]), 7, 2, 3),
+      ("12", np.int64([4, 5, 6]), 7, 2, 5),
+      ("13", np.int64([4, 5, 6]), 7, 2, 7),
+      ("14", np.int64([]), 2, 3, None),
+      ("15", np.int64([0, 0, 0]), 2, 3, None),
+      ("16", np.int64([4, 0, 6]), 2, 3, None),
+      ("17", np.int64([4, 0, 6]), 2, 3, 1),
+      ("18", np.int64([4, 0, 6]), 2, 3, 2),
+  )
+  def testInterleaveDataset(self, input_values, cycle_length, block_length,
+                            num_parallel_calls):
+    count = 2
+    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+        count).interleave(
+            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+            cycle_length, block_length, num_parallel_calls)
+    get_next = dataset.make_one_shot_iterator().get_next()
 
-    repeat_count = 2
-
-    dataset = (
-        dataset_ops.Dataset.from_tensor_slices(input_values)
-        .repeat(repeat_count)
-        .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
-                    cycle_length, block_length))
-    iterator = dataset.make_initializable_iterator()
-    init_op = iterator.initializer
-    next_element = iterator.get_next()
+    def repeat(values, count):
+      result = []
+      for value in values:
+        result.append([value] * value)
+      return result * count
 
     with self.test_session() as sess:
-      # Cycle length 1 acts like `Dataset.flat_map()`.
-      sess.run(init_op, feed_dict={input_values: [4, 5, 6],
-                                   cycle_length: 1, block_length: 3})
-
       for expected_element in self._interleave(
-          [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
-        self.assertEqual(expected_element, sess.run(next_element))
+          repeat(input_values, count), cycle_length, block_length):
+        self.assertEqual(expected_element, sess.run(get_next))
 
-      # Cycle length > 1.
-      # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
-      #            6, 5, 6, 5, 6, 5, 6, 5]
-      sess.run(init_op, feed_dict={input_values: [4, 5, 6],
-                                   cycle_length: 2, block_length: 1})
-      for expected_element in self._interleave(
-          [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
-        self.assertEqual(expected_element, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
+      for _ in range(2):
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
 
-      # Cycle length > 1 and block length > 1.
-      # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
-      #            5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
-      sess.run(init_op, feed_dict={input_values: [4, 5, 6],
-                                   cycle_length: 2, block_length: 3})
-      for expected_element in self._interleave(
-          [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
-        self.assertEqual(expected_element, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
+  @parameterized.named_parameters(
+      ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
+      ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
+      ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
+      ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
+      ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
+      ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
+      ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
+      ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
+      ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
+      ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
+      ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
+      ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
+      ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
+  )
+  def testInterleaveErrorDataset(self,
+                                 input_values,
+                                 cycle_length,
+                                 block_length,
+                                 num_parallel_calls):
+    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+        lambda x: array_ops.check_numerics(x, "message")).interleave(
+            dataset_ops.Dataset.from_tensors, cycle_length, block_length,
+            num_parallel_calls)
+    get_next = dataset.make_one_shot_iterator().get_next()
 
-      # Cycle length > len(input_values) * repeat_count.
-      # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
-      #            5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
-      sess.run(init_op, feed_dict={input_values: [4, 5, 6],
-                                   cycle_length: 7, block_length: 2})
-      for expected_element in self._interleave(
-          [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
-        self.assertEqual(expected_element, sess.run(next_element))
+    with self.cached_session() as sess:
+      for value in input_values:
+        if np.isnan(value):
+          with self.assertRaises(errors.InvalidArgumentError):
+            sess.run(get_next)
+        else:
+          self.assertEqual(value, sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-      # Empty input.
-      sess.run(init_op, feed_dict={input_values: [],
-                                   cycle_length: 2, block_length: 3})
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-      # Non-empty input leading to empty output.
-      sess.run(init_op, feed_dict={input_values: [0, 0, 0],
-                                   cycle_length: 2, block_length: 3})
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-      # Mixture of non-empty and empty interleaved datasets.
-      sess.run(init_op, feed_dict={input_values: [4, 0, 6],
-                                   cycle_length: 2, block_length: 3})
-      for expected_element in self._interleave(
-          [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
-        self.assertEqual(expected_element, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
+        sess.run(get_next)
 
   def testSparse(self):
 
@@ -201,20 +204,6 @@
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
-  def testEmptyInput(self):
-    iterator = (
-        dataset_ops.Dataset.from_tensor_slices([])
-        .repeat(None)
-        .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
-        .make_initializable_iterator())
-    init_op = iterator.initializer
-    get_next = iterator.get_next()
-
-    with self.test_session() as sess:
-      sess.run(init_op)
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
-
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index b0414ad..671e5d4 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -91,7 +91,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(14):
         for i in range(7):
           result = sess.run(get_next)
@@ -117,7 +117,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(14):
         for i in range(7):
           result = sess.run(get_next)
@@ -208,7 +208,7 @@
     iterator = dataset.make_one_shot_iterator()
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
         sess.run(next_element)
 
@@ -216,7 +216,7 @@
       with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
         sess.run(next_element)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       def consumer_thread():
         with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
@@ -287,7 +287,7 @@
         .make_initializable_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors.FailedPreconditionError,
                                    "iterator has not been initialized"):
         sess.run(get_next)
@@ -308,7 +308,7 @@
     self.assertEqual(dataset_4.output_types, iterator.output_types)
     self.assertEqual([None], iterator.output_shapes.as_list())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The iterator is initially uninitialized.
       with self.assertRaises(errors.FailedPreconditionError):
         sess.run(get_next)
@@ -380,7 +380,7 @@
     self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
     self.assertEqual([], feedable_iterator.output_shapes)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       iterator_3_handle = sess.run(iterator_3.string_handle())
       iterator_4_handle = sess.run(iterator_4.string_handle())
 
@@ -436,7 +436,7 @@
       self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
       self.assertEqual([], feedable_iterator.output_shapes)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         iterator_3_handle = sess.run(iterator_3.string_handle())
         iterator_4_handle = sess.run(iterator_4.string_handle())
 
@@ -524,7 +524,7 @@
     feedable_int_any = iterator_ops.Iterator.from_string_handle(
         handle_placeholder, dtypes.int32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       handle_int_scalar = sess.run(
           dataset_int_scalar.make_one_shot_iterator().string_handle())
       handle_float_vector = sess.run(
@@ -687,7 +687,7 @@
           f=_remote_fn,
           target=target_placeholder)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       elem = sess.run(
           remote_op,
           feed_dict={
@@ -803,16 +803,15 @@
     get_next = iterator.get_next if context.executing_eagerly(
     ) else functools.partial(self.evaluate, iterator.get_next())
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
-    with self.test_session() as sess:
-      self.assertAllEqual([1, 4], get_next())
-      save_path = checkpoint.save(checkpoint_prefix)
-      self.assertAllEqual([9, 16], get_next())
-      self.assertAllEqual([25, 36], get_next())
-      checkpoint.restore(save_path).run_restore_ops(sess)
-      self.assertAllEqual([9, 16], get_next())
-      self.assertAllEqual([25, 36], get_next())
-      with self.assertRaises(errors.OutOfRangeError):
-        get_next()
+    self.assertAllEqual([1, 4], get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([9, 16], get_next())
+    self.assertAllEqual([25, 36], get_next())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual([9, 16], get_next())
+    self.assertAllEqual([25, 36], get_next())
+    with self.assertRaises(errors.OutOfRangeError):
+      get_next()
 
   @test_util.run_in_graph_and_eager_modes
   def testSaveRestoreMultipleIterator(self):
@@ -833,19 +832,18 @@
     ) else functools.partial(self.evaluate, iterator_3.get_next())
     checkpoint = checkpointable_utils.Checkpoint(
         iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
-    with self.test_session() as sess:
-      self.assertAllEqual([1, 4], get_next_1())
-      self.assertAllEqual(0, get_next_3())
-      self.assertAllEqual(1, get_next_3())
-      self.assertAllEqual(2, get_next_3())
-      save_path = checkpoint.save(checkpoint_prefix)
-      self.assertAllEqual([1, 4], get_next_2())
-      self.assertAllEqual([9, 16], get_next_2())
-      self.assertAllEqual(3, get_next_3())
-      checkpoint.restore(save_path).run_restore_ops(sess)
-      self.assertAllEqual([9, 16], get_next_1())
-      self.assertAllEqual([1, 4], get_next_2())
-      self.assertAllEqual(3, get_next_3())
+    self.assertAllEqual([1, 4], get_next_1())
+    self.assertAllEqual(0, get_next_3())
+    self.assertAllEqual(1, get_next_3())
+    self.assertAllEqual(2, get_next_3())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual([1, 4], get_next_2())
+    self.assertAllEqual([9, 16], get_next_2())
+    self.assertAllEqual(3, get_next_3())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual([9, 16], get_next_1())
+    self.assertAllEqual([1, 4], get_next_2())
+    self.assertAllEqual(3, get_next_3())
 
   @test_util.run_in_graph_and_eager_modes
   def testRestoreExhaustedIterator(self):
@@ -856,17 +854,16 @@
     get_next = iterator.get_next if context.executing_eagerly(
     ) else functools.partial(self.evaluate, iterator.get_next())
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
-    with self.test_session() as sess:
-      self.assertAllEqual(0, get_next())
-      self.assertAllEqual(1, get_next())
-      save_path = checkpoint.save(checkpoint_prefix)
-      self.assertAllEqual(2, get_next())
-      checkpoint.restore(save_path).run_restore_ops(sess)
-      self.assertAllEqual(2, get_next())
-      save_path = checkpoint.save(checkpoint_prefix)
-      checkpoint.restore(save_path).run_restore_ops(sess)
-      with self.assertRaises(errors.OutOfRangeError):
-        get_next()
+    self.assertAllEqual(0, get_next())
+    self.assertAllEqual(1, get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    self.assertAllEqual(2, get_next())
+    checkpoint.restore(save_path).run_restore_ops()
+    self.assertAllEqual(2, get_next())
+    save_path = checkpoint.save(checkpoint_prefix)
+    checkpoint.restore(save_path).run_restore_ops()
+    with self.assertRaises(errors.OutOfRangeError):
+      get_next()
 
   def testRestoreInReconstructedIteratorInitializable(self):
     checkpoint_directory = self.get_temp_dir()
@@ -876,7 +873,7 @@
     get_next = iterator.get_next()
     checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
     for i in range(5):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         checkpoint.restore(checkpoint_management.latest_checkpoint(
             checkpoint_directory)).initialize_or_restore(sess)
         for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index 579096f..c4b338a 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -44,7 +44,7 @@
 
   def testEmptyDirectory(self):
     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_one_shot_iterator()
       next_element = itr.get_next()
       with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@
     self._touchTempFiles(filenames)
 
     dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_one_shot_iterator()
       next_element = itr.get_next()
 
@@ -75,7 +75,7 @@
 
     dataset = dataset_ops.Dataset.list_files(
         path.join(self.tmp_dir, '*'), shuffle=False)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_one_shot_iterator()
       next_element = itr.get_next()
 
@@ -91,7 +91,7 @@
 
     dataset = dataset_ops.Dataset.list_files(
         path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_initializable_iterator()
       next_element = itr.get_next()
 
@@ -121,7 +121,7 @@
     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_initializable_iterator()
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError, 'No files matched pattern: '):
@@ -136,7 +136,7 @@
     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_initializable_iterator()
       next_element = itr.get_next()
       sess.run(
@@ -162,7 +162,7 @@
     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_initializable_iterator()
       next_element = itr.get_next()
       sess.run(
@@ -187,7 +187,7 @@
     filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
     dataset = dataset_ops.Dataset.list_files(filename_placeholder)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_initializable_iterator()
       next_element = itr.get_next()
       sess.run(
@@ -221,7 +221,7 @@
     # more meaningful.
     dataset = dataset_ops.Dataset.list_files(
         path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       itr = dataset.make_one_shot_iterator()
       next_element = itr.get_next()
 
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 52b4320..7685d8d 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -22,6 +22,7 @@
 import time
 import warnings
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.framework import attr_value_pb2
@@ -46,7 +47,7 @@
 from tensorflow.python.platform import test
 
 
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test.TestCase, parameterized.TestCase):
 
   def _buildMapDataset(self, components, count):
     def _map_fn(x, y, z):
@@ -71,7 +72,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test single-threaded access to the iterator.
       sess.run(init_op, feed_dict={count: 14})
       for _ in range(14):
@@ -137,7 +138,8 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
+
       def do_test(num_parallel_calls_val, output_buffer_size_val):
         # Test single-threaded access to the iterator.
         sess.run(init_op, feed_dict={
@@ -202,7 +204,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(3):
         sess.run(get_next)
@@ -217,7 +219,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(3):
         sess.run(get_next)
@@ -232,7 +234,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(3):
         sess.run(get_next)
@@ -253,7 +255,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for _ in range(3):
         sess.run(get_next)
@@ -284,7 +286,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(table.init)
       sess.run(init_op)
       sess.run(get_next)
@@ -302,7 +304,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(enqueue_op)
       sess.run(close_op)
       sess.run(init_op)
@@ -327,7 +329,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(enqueue_op)
       sess.run(close_op)
       sess.run(init_op)
@@ -346,7 +348,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(counter_var.initializer)
       sess.run(init_op)
       for i in range(10):
@@ -366,7 +368,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.NotFoundError):
         sess.run(get_next)
@@ -378,7 +380,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       random_values = []
       with self.assertRaises(errors.OutOfRangeError):
@@ -403,7 +405,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         self.assertEqual(i * 2 + i ** 2, sess.run(get_next))
@@ -435,7 +437,7 @@
     next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next()
 
     # make sure both datasets contain the same data
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for i in range(count):
         tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
         self.assertEqual(tuple_, namedtuple_)
@@ -453,7 +455,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       self.assertAllEqual(row ** 2, sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
@@ -484,7 +486,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Simple test that prefetch yields the expected values in the
       # expected order.
       for buffer_size in [1, 10, 100, 1000]:
@@ -522,7 +524,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         self.assertEqual((i, 37.0), sess.run(get_next))
@@ -543,7 +545,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         self.assertEqual((i, 37.0), sess.run(get_next))
@@ -569,7 +571,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         actual = sess.run(get_next)
@@ -596,7 +598,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         actual = sess.run(get_next)
@@ -620,7 +622,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(100):
         self.assertEqual(i, sess.run(get_next))
@@ -634,7 +636,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       for i in range(10):
         self.assertEqual((i, b"hello", 10), sess.run(get_next))
@@ -701,67 +703,113 @@
     dataset = dataset.map(broken_function)
     iterator = dataset.make_initializable_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
         sess.run(iterator.initializer)
 
+  # pylint: disable=g-long-lambda
+  @parameterized.named_parameters(
+      ("Map", lambda dataset, func:
+       dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)),
+      ("ParallelMap", lambda dataset, func:
+       dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1,
+                                      use_inter_op_parallelism=False)),
+  )
+  def testNoInterOpParallelism(self, make_dataset_fn):
+    dataset = dataset_ops.Dataset.from_tensors(0)
+
+    def _get_tid():
+      return np.int64(threading.current_thread().ident)
+
+    def _map_fn(_):
+      tids = []
+      for _ in range(10):
+        tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
+      return tids
+
+    dataset = make_dataset_fn(dataset, _map_fn)
+    iterator = dataset.make_one_shot_iterator()
+    get_next = iterator.get_next()
+
+    with self.cached_session() as sess:
+      tids = sess.run(get_next)
+      self.assertTrue(all(tids[0] == tid for tid in tids))
+  # pylint: enable=g-long-lambda
+
 
 class MapDatasetBenchmark(test.Benchmark):
 
   def benchmarkChainOfMaps(self):
     chain_lengths = [0, 1, 2, 5, 10, 20, 50]
     for chain_length in chain_lengths:
-      with ops.Graph().as_default():
-        dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
-        for _ in range(chain_length):
-          dataset = dataset.map(lambda x: x)
-        iterator = dataset.make_one_shot_iterator()
-        next_element = iterator.get_next()
+      for use_inter_op_parallelism in [False, True]:
+        with ops.Graph().as_default():
+          dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+          for _ in range(chain_length):
+            dataset = dataset_ops.MapDataset(
+                dataset,
+                lambda x: x,
+                use_inter_op_parallelism=use_inter_op_parallelism)
+          iterator = dataset.make_one_shot_iterator()
+          next_element = iterator.get_next()
 
-        with session.Session() as sess:
-          for _ in range(5):
-            sess.run(next_element.op)
-          deltas = []
-          for _ in range(100):
-            start = time.time()
-            for _ in range(100):
+          with session.Session() as sess:
+            for _ in range(5):
               sess.run(next_element.op)
-            end = time.time()
-            deltas.append(end - start)
+            deltas = []
+            for _ in range(100):
+              start = time.time()
+              for _ in range(100):
+                sess.run(next_element.op)
+              end = time.time()
+              deltas.append(end - start)
 
-          median_wall_time = np.median(deltas) / 100
-          print("Map dataset chain length: %d Median wall time: %f"
-                % (chain_length, median_wall_time))
-          self.report_benchmark(
-              iters=1000, wall_time=median_wall_time,
-              name="benchmark_map_dataset_chain_latency_%d" % chain_length)
+            median_wall_time = np.median(deltas) / 100
+            print("Map dataset chain length%s: %d Median wall time: %f" %
+                  (" (single threaded mode)" if not use_inter_op_parallelism
+                   else "", chain_length, median_wall_time))
+            self.report_benchmark(
+                iters=1000,
+                wall_time=median_wall_time,
+                name="benchmark_map_dataset_chain_latency_%d%s" %
+                (chain_length, "_single_threaded"
+                 if not use_inter_op_parallelism else ""))
 
   def benchmarkMapFanOut(self):
     fan_outs = [1, 2, 5, 10, 20, 50, 100]
     for fan_out in fan_outs:
-      with ops.Graph().as_default():
-        dataset = dataset_ops.Dataset.from_tensors(
-            tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs)
-        iterator = dataset.make_one_shot_iterator()
-        next_element = iterator.get_next()
+      for use_inter_op_parallelism in [False, True]:
+        with ops.Graph().as_default():
+          dataset = dataset_ops.Dataset.from_tensors(
+              tuple(0 for _ in range(fan_out))).repeat(None)
+          dataset = dataset_ops.MapDataset(
+              dataset,
+              lambda *xs: xs,
+              use_inter_op_parallelism=use_inter_op_parallelism)
+          iterator = dataset.make_one_shot_iterator()
+          next_element = iterator.get_next()
 
-        with session.Session() as sess:
-          for _ in range(5):
-            sess.run(next_element[0].op)
-          deltas = []
-          for _ in range(100):
-            start = time.time()
-            for _ in range(100):
+          with session.Session() as sess:
+            for _ in range(5):
               sess.run(next_element[0].op)
-            end = time.time()
-            deltas.append(end - start)
+            deltas = []
+            for _ in range(100):
+              start = time.time()
+              for _ in range(100):
+                sess.run(next_element[0].op)
+              end = time.time()
+              deltas.append(end - start)
 
-          median_wall_time = np.median(deltas) / 100
-          print("Map dataset fan out: %d Median wall time: %f"
-                % (fan_out, median_wall_time))
-          self.report_benchmark(
-              iters=1000, wall_time=median_wall_time,
-              name="benchmark_map_dataset_fan_out_%d" % fan_out)
+            median_wall_time = np.median(deltas) / 100
+            print("Map dataset fan out%s: %d Median wall time: %f" %
+                  (" (single threaded mode)" if not use_inter_op_parallelism
+                   else "", fan_out, median_wall_time))
+            self.report_benchmark(
+                iters=1000,
+                wall_time=median_wall_time,
+                name="benchmark_map_dataset_fan_out_%d%s" %
+                (fan_out, "_single_threaded"
+                 if not use_inter_op_parallelism else ""))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index a32527a..c344513 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -158,7 +158,7 @@
     self.assertEqual(ds.output_classes, next_elem.output_classes)
     elem_has_value_t = next_elem.has_value()
     elem_value_t = next_elem.get_value()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Before initializing the iterator, evaluating the optional fails with
       # a FailedPreconditionError.
       with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index 63a0830..cc97bac 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -36,7 +36,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
       for m in range(10):
         self.assertEqual(m, sess.run(get_next))
@@ -51,7 +51,7 @@
     init_op = iterator.initializer
 
     with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
 
 
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index ad87f31..51e9078 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -49,7 +49,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={stop: 5})
       for i in range(5):
         self.assertEqual(i, sess.run(get_next))
@@ -64,7 +64,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 2, stop: 5})
       for i in range(2, 5):
         self.assertEqual(i, sess.run(get_next))
@@ -80,7 +80,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
       for i in range(2, 10, 2):
         self.assertEqual(i, sess.run(get_next))
@@ -95,7 +95,7 @@
                                          step).make_initializable_iterator()
     init_op = iterator.initializer
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.InvalidArgumentError):
         sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
 
@@ -108,7 +108,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
       # This for loop is a no-op but will ensure that the implementation is
       # consistent with range if it ever changes.
@@ -125,7 +125,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 10, stop: 2})
       # This for loop is a no-op but will ensure that the implementation is
       # consistent with range if it ever changes.
@@ -143,7 +143,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
       # This for loop is a no-op but will ensure that the implementation is
       # consistent with range if it ever changes.
@@ -161,7 +161,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
       for i in range(10, 2, -1):
         self.assertEqual(i, sess.run(get_next))
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index 431362a..aa36363 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -100,7 +100,7 @@
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from file 0.
       sess.run(
           init_op, feed_dict={filenames: [test_filenames[0]],
@@ -163,7 +163,7 @@
     repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
     iterator = repeat_dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for j in range(2):
         for i in range(5):
           self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
@@ -240,7 +240,7 @@
     init_batch_op = iterator.make_initializer(batch_dataset)
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from file 0.
       sess.run(
           init_op, feed_dict={filenames: [test_filenames[0]],
@@ -302,7 +302,7 @@
         buffer_size=10)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for j in range(self._num_files):
         for i in range(self._num_records):
           self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
@@ -319,7 +319,7 @@
         buffer_size=10)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
@@ -661,7 +661,7 @@
     return filenames
 
   def testReadOneEpoch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Basic test: read from file 0.
       sess.run(
           self.init_op,
@@ -698,7 +698,7 @@
         sess.run(self.get_next)
 
   def testReadTenEpochs(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_op,
           feed_dict={self.filenames: self.test_filenames,
@@ -711,7 +711,7 @@
         sess.run(self.get_next)
 
   def testReadTenEpochsOfBatches(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_batch_op,
           feed_dict={
@@ -738,7 +738,7 @@
           f.write(cdata)
         zlib_files.append(zfn)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_op,
           feed_dict={self.filenames: zlib_files,
@@ -758,7 +758,7 @@
           gzf.write(f.read())
         gzip_files.append(gzfn)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(
           self.init_op,
           feed_dict={self.filenames: gzip_files,
@@ -774,7 +774,7 @@
     d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
     iterator = d.make_one_shot_iterator()
     next_element = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for j in range(self._num_files):
         for i in range(self._num_records):
           self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -786,7 +786,7 @@
     d = readers.TFRecordDataset(files)
     iterator = d.make_one_shot_iterator()
     next_element = iterator.get_next()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for j in range(self._num_files):
         for i in range(self._num_records):
           self.assertAllEqual(self._record(j, i), sess.run(next_element))
@@ -801,7 +801,7 @@
     next_element = iterator.get_next()
     expected = []
     actual = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(10):
         for j in range(self._num_files):
           for i in range(self._num_records):
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 1d27b03..37e2333 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -44,7 +44,7 @@
     self.assertEqual([c.shape for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test a finite repetition.
       sess.run(init_op, feed_dict={count_placeholder: 3})
       for _ in range(3):
@@ -90,7 +90,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Take fewer than input size
       sess.run(init_op, feed_dict={count_placeholder: 4})
       for i in range(4):
@@ -136,7 +136,7 @@
     self.assertEqual([c.shape[1:] for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Skip fewer than input size, we should skip
       # the first 4 elements and then read the rest.
       sess.run(init_op, feed_dict={count_placeholder: 4})
@@ -183,7 +183,7 @@
     self.assertEqual([c.shape for c in components],
                      [t.shape for t in get_next])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
       for _ in range(7 * 14):
         results = sess.run(get_next)
@@ -199,7 +199,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_op)
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index cefe872..137f634 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -28,7 +28,7 @@
     dataset = dataset_ops.Dataset.range(10).shard(5, 2)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(2, sess.run(iterator.get_next()))
       self.assertEqual(7, sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
@@ -40,7 +40,7 @@
     dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual((2, 8), sess.run(iterator.get_next()))
       self.assertEqual((7, 3), sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
@@ -50,7 +50,7 @@
     dataset = dataset_ops.Dataset.range(10).shard(5, 0)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(0, sess.run(iterator.get_next()))
       self.assertEqual(5, sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
@@ -76,14 +76,14 @@
     dataset = dataset_ops.Dataset.range(1).shard(5, 2)
     iterator = dataset.make_one_shot_iterator()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(iterator.get_next())
 
   def testLargerWorkerPool(self):
     dataset = dataset_ops.Dataset.range(10).shard(7, 5)
     iterator = dataset.make_one_shot_iterator()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(5, sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(iterator.get_next())
@@ -91,7 +91,7 @@
   def testIndexEqualsNumShards(self):
     dataset = dataset_ops.Dataset.range(10).shard(5, 4)
     iterator = dataset.make_one_shot_iterator()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(4, sess.run(iterator.get_next()))
       self.assertEqual(9, sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
@@ -100,7 +100,7 @@
   def testIndexEqualsNumShards2(self):
     dataset = dataset_ops.Dataset.range(10).shard(4, 3)
     iterator = dataset.make_one_shot_iterator()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(3, sess.run(iterator.get_next()))
       self.assertEqual(7, sess.run(iterator.get_next()))
       with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index 5fcc488..f294840 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -60,7 +60,7 @@
 
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # First run without shuffling to collect the "ground truth".
       sess.run(init_fifo_op)
       unshuffled_elements = []
@@ -140,7 +140,7 @@
     get_next = iterator.get_next()
 
     elems = []
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for _ in range(10):
         elems.append(sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
@@ -152,7 +152,7 @@
         .make_initializable_iterator())
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
       for elem in elems:
         self.assertEqual(elem, sess.run(get_next))
@@ -166,7 +166,7 @@
 
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       counts = collections.defaultdict(lambda: 0)
       for _ in range(10):
         for _ in range(5):
@@ -183,7 +183,7 @@
                 .make_one_shot_iterator())
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       initial_permutation = sess.run(next_element)
       self.assertAllEqual(initial_permutation, sess.run(next_element))
       self.assertAllEqual(initial_permutation, sess.run(next_element))
@@ -198,7 +198,7 @@
                 .make_one_shot_iterator())
     next_element = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       initial_permutation = list(sess.run(next_element))
       for _ in range(2):
         next_permutation = list(sess.run(next_element))
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 5593311..3106eff 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -45,7 +45,7 @@
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       equal_length_components = [
           np.tile(np.array([[1], [2], [3], [4]]), 20),
           np.tile(np.array([[12], [13], [14], [15]]), 22),
@@ -93,7 +93,7 @@
     self.assertEqual([22], get_next[1][0].shape)
     self.assertEqual([], get_next[1][1].shape)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       equal_length_components = [
           np.tile(np.array([[1], [2], [3], [4]]), 20),
           np.tile(np.array([[12], [13], [14], [15]]), 22),
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8c37b18..c985e00 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1019,7 +1019,11 @@
     """
     return FlatMapDataset(self, map_func)
 
-  def interleave(self, map_func, cycle_length, block_length=1):
+  def interleave(self,
+                 map_func,
+                 cycle_length,
+                 block_length=1,
+                 num_parallel_calls=None):
     """Maps `map_func` across this dataset, and interleaves the results.
 
     For example, you can use `Dataset.interleave()` to process many input files
@@ -1082,11 +1086,19 @@
         processed concurrently.
       block_length: The number of consecutive elements to produce from each
         input element before cycling to another input element.
+      num_parallel_calls: (Optional.) If specified, the implementation creates
+        a threadpool, which is used to fetch inputs from cycle elements
+        asynchronously and in parallel. The default behavior is to fetch inputs
+        from cycle elements synchronously with no parallelism.
 
     Returns:
       Dataset: A `Dataset`.
     """
-    return InterleaveDataset(self, map_func, cycle_length, block_length)
+    if num_parallel_calls is None:
+      return InterleaveDataset(self, map_func, cycle_length, block_length)
+    else:
+      return ParallelInterleaveDataset(self, map_func, cycle_length,
+                                       block_length, num_parallel_calls)
 
   def filter(self, predicate):
     """Filters this dataset according to `predicate`.
@@ -2207,10 +2219,11 @@
 class MapDataset(Dataset):
   """A `Dataset` that maps a function over elements in its input."""
 
-  def __init__(self, input_dataset, map_func):
+  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
     """See `Dataset.map()` for details."""
     super(MapDataset, self).__init__()
     self._input_dataset = input_dataset
+    self._use_inter_op_parallelism = use_inter_op_parallelism
 
     wrapped_func = StructuredFunctionWrapper(
         map_func, "Dataset.map()", input_dataset)
@@ -2225,6 +2238,7 @@
         input_t,
         self._map_func.captured_inputs,
         f=self._map_func,
+        use_inter_op_parallelism=self._use_inter_op_parallelism,
         **flat_structure(self))
 
   @property
@@ -2243,9 +2257,14 @@
 class ParallelMapDataset(MapDataset):
   """A `Dataset` that maps a function over elements in its input in parallel."""
 
-  def __init__(self, input_dataset, map_func, num_parallel_calls):
+  def __init__(self,
+               input_dataset,
+               map_func,
+               num_parallel_calls,
+               use_inter_op_parallelism=True):
     """See `Dataset.map()` for details."""
-    super(ParallelMapDataset, self).__init__(input_dataset, map_func)
+    super(ParallelMapDataset, self).__init__(input_dataset, map_func,
+                                             use_inter_op_parallelism)
 
     self._num_parallel_calls = ops.convert_to_tensor(
         num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
@@ -2258,6 +2277,7 @@
         self._map_func.captured_inputs,
         f=self._map_func,
         num_parallel_calls=self._num_parallel_calls,
+        use_inter_op_parallelism=self._use_inter_op_parallelism,
         **flat_structure(self))
     # pylint: enable=protected-access
 
@@ -2328,6 +2348,36 @@
     return "Dataset.interleave()"
 
 
+class ParallelInterleaveDataset(FlatMapDataset):
+  """A `Dataset` that maps a function over its input and interleaves the
+  result.
+  """
+
+  def __init__(self, input_dataset, map_func, cycle_length, block_length,
+               num_parallel_calls):
+    """See `Dataset.interleave()` for details."""
+    super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
+    self._cycle_length = ops.convert_to_tensor(
+        cycle_length, dtype=dtypes.int64, name="cycle_length")
+    self._block_length = ops.convert_to_tensor(
+        block_length, dtype=dtypes.int64, name="block_length")
+    self._num_parallel_calls = ops.convert_to_tensor(
+        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.parallel_interleave_dataset_v2(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        self._map_func.captured_inputs,  # pylint: disable=protected-access
+        self._cycle_length,
+        self._block_length,
+        self._num_parallel_calls,
+        f=self._map_func,  # pylint: disable=protected-access
+        **flat_structure(self))
+
+  def _transformation_name(self):
+    return "Dataset.interleave()"
+
+
 class FilterDataset(Dataset):
   """A `Dataset` that filters its input according to a predicate function."""
 
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 6a67093..89c3afb 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -30,28 +30,28 @@
 
   def testInteger(self):
     resp = convert.optional_param_to_tensor("foo", 3)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(3, sess.run(resp))
 
   def testIntegerDefault(self):
     resp = convert.optional_param_to_tensor("foo", None)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(0, sess.run(resp))
 
   def testStringDefault(self):
     resp = convert.optional_param_to_tensor("bar", None, "default",
                                             dtypes.string)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(compat.as_bytes("default"), sess.run(resp))
 
   def testString(self):
     resp = convert.optional_param_to_tensor("bar", "value", "default",
                                             dtypes.string)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(compat.as_bytes("value"), sess.run(resp))
 
   def testPartialShapeToTensorKnownDimension(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
           tensor_shape.TensorShape([1]))))
       self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
@@ -60,7 +60,7 @@
           constant_op.constant([1], dtype=dtypes.int64))))
 
   def testPartialShapeToTensorUnknownDimension(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
           tensor_shape.TensorShape([None]))))
       self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
@@ -84,7 +84,7 @@
       convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
 
   def testPartialShapeToTensorMultipleDimensions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
           tensor_shape.TensorShape([3, 6]))))
       self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
@@ -113,7 +113,7 @@
           constant_op.constant([-1, -1], dtype=dtypes.int64))))
 
   def testPartialShapeToTensorScalar(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
           tensor_shape.TensorShape([]))))
       self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 9d621fc..e5abc65 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -96,37 +96,11 @@
       yield value
 
 
-def is_sequence(seq):
-  """Returns a true if `seq` is a Sequence or dict (except strings/lists).
+# See the swig file (../../util/util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequenceForData
 
-  NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
-  which *does* treat a Python list as a sequence. For ergonomic
-  reasons, `tf.data` users would prefer to treat lists as
-  implicit `tf.Tensor` objects, and dicts as (nested) sequences.
-
-  Args:
-    seq: an input sequence.
-
-  Returns:
-    True if the sequence is a not a string or list and is a
-    collections.Sequence.
-  """
-  return _pywrap_tensorflow.IsSequenceForData(seq)
-
-
-def flatten(nest):
-  """Returns a flat sequence from a given nested structure.
-
-  If `nest` is not a sequence, this returns a single-element list: `[nest]`.
-
-  Args:
-    nest: an arbitrarily nested structure or a scalar object.
-      Note, numpy arrays are considered scalars.
-
-  Returns:
-    A Python list, the flattened version of the input.
-  """
-  return _pywrap_tensorflow.FlattenForData(nest)
+# See the swig file (../../util/util.i) for documentation.
+flatten = _pywrap_tensorflow.FlattenForData
 
 
 def assert_same_structure(nest1, nest2, check_types=True):
diff --git a/tensorflow/python/data/util/sparse_test.py b/tensorflow/python/data/util/sparse_test.py
index d49b3ff..056b324 100644
--- a/tensorflow/python/data/util/sparse_test.py
+++ b/tensorflow/python/data/util/sparse_test.py
@@ -291,7 +291,7 @@
       self.assertEqual(a, b)
       return
     self.assertTrue(isinstance(b, sparse_tensor.SparseTensor))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(a.eval().indices, b.eval().indices)
       self.assertAllEqual(a.eval().values, b.eval().values)
       self.assertAllEqual(a.eval().dense_shape, b.eval().dense_shape)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 6f48d38..c1bc27d 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -241,7 +241,7 @@
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:gradients",
+        "//tensorflow/python:gradients_impl",
         "//tensorflow/python:graph_to_function_def",
         "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
@@ -345,6 +345,7 @@
     deps = [
         ":backprop",
         ":context",
+        ":core",
         ":test",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 9891068..be392c7 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -216,9 +216,7 @@
                        "function was being computed.")
 
     sources = [v.handle for v in variables]
-    grad = imperative_grad.imperative_grad(_default_vspace,
-                                           this_tape,
-                                           nest.flatten(end_node),
+    grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
                                            sources)
     return end_node, list(zip(grad, variables))
 
@@ -537,8 +535,8 @@
       if dy is not None:
         dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
       return imperative_grad.imperative_grad(
-          _default_vspace, this_tape, nest.flatten(result), sources,
-          output_gradients=dy)
+          this_tape, nest.flatten(result), sources, output_gradients=dy)
+
     return result, vjp
 
   return decorated
@@ -631,9 +629,9 @@
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
-    tensor_id=ops.tensor_id,
     zeros=_zeros,
     ones=_ones)
+pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
 
 
 def _handle_or_self(x):
@@ -695,19 +693,57 @@
   del g  # Drop the reference to the tape
   ```
 
+  By default GradientTape will automatically watch any trainable variables that
+  are accessed inside the context. If you want fine grained control over which
+  variables are watched you can disable automatic tracking by passing
+  `watch_accessed_variables=False` to the tape constructor:
+
+  ```python
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(variable_a)
+    y = variable_a ** 2  # Gradients will be available for `variable_a`.
+    z = variable_b ** 3  # No gradients will be available since `variable_b` is
+                         # not being watched.
+  ```
+
+  Note that when using models you should ensure that your variables exist when
+  using `watch_accessed_variables=False`. Otherwise it's quite easy to make your
+  first iteration not have any gradients:
+
+  ```python
+  a = tf.keras.layers.Dense(32)
+  b = tf.keras.layers.Dense(32)
+
+  with tf.GradientTape(watch_accessed_variables=False) as tape:
+    tape.watch(a.variables)  # Since `a.build` has not been called at this point
+                             # `a.variables` will return an empty list and the
+                             # tape will not be watching anything.
+    result = b(a(inputs))
+    tape.gradient(result, a.variables)  # The result of this computation will be
+                                        # a list of `None`s since a's variables
+                                        # are not being watched.
+  ```
+
   Note that only tensors with real or complex dtypes are differentiable.
   """
 
-  def __init__(self, persistent=False):
+  def __init__(self, persistent=False, watch_accessed_variables=True):
     """Creates a new GradientTape.
 
     Args:
       persistent: Boolean controlling whether a persistent gradient tape
         is created. False by default, which means at most one call can
         be made to the gradient() method on this object.
+      watch_accessed_variables: Boolean controlling whether the tape will
+        automatically `watch` any (trainable) variables accessed while the tape
+        is active. Defaults to True meaning gradients can be requested from any
+        result computed in the tape derived from reading a trainable `Variable`.
+        If False users must explicitly `watch` any `Variable`s they want to
+        request gradients from.
     """
     self._tape = None
     self._persistent = persistent
+    self._watch_accessed_variables = watch_accessed_variables
     self._recording = False
     context.context().start_step()
 
@@ -721,15 +757,15 @@
     if self._recording:
       self._pop_tape()
 
-  def _push_tape(self, existing_tape=False):
+  def _push_tape(self):
     if self._recording:
       raise ValueError("Tape is already recording.")
-    if existing_tape:
-      if self._tape is None:
-        raise ValueError("There is no existing tape.")
-      tape.push_tape(self._tape)
+    if self._tape is None:
+      self._tape = tape.push_new_tape(
+          persistent=self._persistent,
+          watch_accessed_variables=self._watch_accessed_variables)
     else:
-      self._tape = tape.push_new_tape(persistent=self._persistent)
+      tape.push_tape(self._tape)
     self._recording = True
 
   def _pop_tape(self):
@@ -748,7 +784,13 @@
       tensor: a Tensor or list of Tensors.
     """
     for t in nest.flatten(tensor):
-      tape.watch(self._tape, _handle_or_self(t))
+      if hasattr(t, "handle"):
+        # There are many variable-like objects, all of them currently have
+        # `handle` attribute that points to a tensor. If this changes, internals
+        # of watch_variable need to change as well.
+        tape.watch_variable(self._tape, t)
+      else:
+        tape.watch(self._tape, t)
 
   @tf_contextlib.contextmanager
   def stop_recording(self):
@@ -780,7 +822,7 @@
     try:
       yield
     finally:
-      self._push_tape(existing_tape=True)
+      self._push_tape()
 
   def reset(self):
     """Clears all information stored in this tape.
@@ -814,6 +856,7 @@
     ```
     """
     self._pop_tape()
+    self._tape = None
     self._push_tape()
 
   def watched_variables(self):
@@ -865,7 +908,9 @@
                           for x in nest.flatten(output_gradients)]
 
     flat_grad = imperative_grad.imperative_grad(
-        _default_vspace, self._tape, nest.flatten(target), flat_sources,
+        self._tape,
+        nest.flatten(target),
+        flat_sources,
         output_gradients=output_gradients)
 
     if not self._persistent:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index caf36b6..f938ed5 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -64,7 +64,7 @@
     grad = backprop.gradients_function(fn, [0])(var)[0]
     grad = self.evaluate(ops.convert_to_tensor(grad))
 
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode():
       tf_var = array_ops.constant(var_np, dtypes.float32)
       tf_ind1 = array_ops.constant([0, 1])
       tf_ind2 = array_ops.constant([2, 3])
@@ -79,7 +79,7 @@
       tf_dense_grad = math_ops.unsorted_segment_sum(
           tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
 
-      self.assertAllClose(grad, tf_dense_grad.eval())
+      self.assertAllClose(grad, self.evaluate(tf_dense_grad))
 
   def testImplicitGradWithResourceVariable(self):
     x = resource_variable_ops.ResourceVariable(
@@ -198,7 +198,7 @@
     grad = backprop.implicit_grad(f)()[0][0]
     opt = training.GradientDescentOptimizer(lrn_rate)
 
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       tf_x = array_ops.ones((batch_size), dtypes.int64)
       # TODO(ashankar,apassos): Change to ResourceVariable.
       tf_embedding = variables.Variable(
@@ -474,6 +474,18 @@
     self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
 
   @test_util.assert_no_new_tensors
+  def testGradientTapeReEnterContext(self):
+    g = backprop.GradientTape()
+    with g:
+      x = constant_op.constant(3.0)
+      g.watch(x)
+      y = 2*x
+    with g:
+      z = 2*y
+    grad = g.gradient(target=z, sources=[x])
+    self.assertEqual(self.evaluate(grad), [4.0])
+
+  @test_util.assert_no_new_tensors
   @test_util.run_in_graph_and_eager_modes
   def testGradientTapeRepeatedSource(self):
     with backprop.GradientTape(persistent=False) as g:
@@ -941,7 +953,7 @@
   def testZerosCacheDoesntLeakAcrossGraphs(self):
     with context.graph_mode():
       def get_grad():
-        with ops.Graph().as_default(), self.test_session():
+        with ops.Graph().as_default(), self.cached_session():
           t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4))
           x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4))
           with backprop.GradientTape() as tape:
@@ -956,6 +968,60 @@
 
       self.assertAllEqual(grad1, grad2)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testSelectivelyWatchVariables(self):
+    x1 = resource_variable_ops.ResourceVariable(1.0)
+    x2 = resource_variable_ops.ResourceVariable(1.0)
+    with backprop.GradientTape(watch_accessed_variables=False) as tape:
+      tape.watch(x2)
+      y = x1**2
+      z = x2**3
+    self.assertTupleEqual(tape.watched_variables(), (x2,))
+    dy, dz = tape.gradient([y, z], [x1, x2])
+    self.evaluate([x1.initializer, x2.initializer])
+    self.assertIsNone(dy)
+    self.assertEqual(self.evaluate(dz), 3.0)
+
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDifferentiatingScalarCache(self):
+    # In the following test, if x2 = x1 (i.e the objects are the exact same),
+    # then y is essentially, 2*x1, and dy/dx1 = 2.
+    # When we had a pure scalar cache in eager, this would be the case. This
+    # test prevents us from going back to that case.
+    with backprop.GradientTape(persistent=False) as g:
+      x1 = constant_op.constant(3.0)
+      x2 = constant_op.constant(3.0)
+      g.watch(x1)
+      g.watch(x2)
+      y = x1 + x2
+    grad = g.gradient(target=y, sources=[x1])
+    self.assertEqual(self.evaluate(grad), [1.0])
+
+  def testVariablesAndConstantsProduceTheSameGradients(self):
+
+    # In the following test, differentiating [y, z] against [a, b] gives:
+    # (dy/da + dz/da, dy/db + dz/db).
+    # If a and b are the same constant, dz/da will not be 0 (which it should
+    # be).
+    # This is solved by using variables, since doing a read_value on a tensor
+    # produce a new tensor and corresponding TensorHandle, and not reuse the
+    # same tensor (which would happen if we are using a cache and reusing
+    # EagerTensor objects).
+    def get_grads(a, b):
+      with backprop.GradientTape() as tape:
+        tape.watch([a, b])
+        y = a**3
+        z = b**2
+      return tape.gradient([y, z], [a, b])
+
+    gradients_constants = get_grads(
+        constant_op.constant(2.0), constant_op.constant(2.0))
+    gradients_variables = get_grads(
+        resource_variable_ops.ResourceVariable(2.0),
+        resource_variable_ops.ResourceVariable(2.0))
+    self.assertAllEqual(gradients_constants, gradients_variables)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index a2e8422..3fe79ef 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -42,6 +42,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_math_ops
 from tensorflow.python.ops import math_ops
@@ -175,6 +176,11 @@
 
     self._run(func, 30000)
 
+  def benchmark_create_constant(self):
+    func = lambda: constant_op.constant(3.0)
+
+    self._run(func, 30000)
+
   def benchmark_create_float_tensor_from_list_CPU(self):
     self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
 
@@ -712,6 +718,25 @@
     assert np.equal(func(), make_keras_model()(data)).all()
     self._run(func, 30000)
 
+  def benchmarkScan(self):
+    elems = math_ops.range(1600)
+
+    def scan():
+      return functional_ops.scan(
+          lambda a, x: a + x, elems, parallel_iterations=1)
+
+    self._run(scan, 100)
+
+  def benchmarkScanDefun(self):
+    elems = math_ops.range(1600)
+
+    @function.defun
+    def scan():
+      return functional_ops.scan(
+          lambda a, x: a + x, elems, parallel_iterations=1)
+
+    self._run(scan, 100)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6c87dcc..552ed29 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -27,6 +27,7 @@
 import numpy as np
 import six
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
@@ -34,6 +35,7 @@
 from tensorflow.python.eager import tape
 from tensorflow.python.eager.graph_only_ops import graph_placeholder
 from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import ops
@@ -55,8 +57,15 @@
 # (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
 cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
 
+# This is to avoid a circular dependency with gradients_impl
+gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
 
-def create_substitute_placeholder(value, name, dtype=None):
+
+# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+
+
+def _create_substitute_placeholder(value, name, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
@@ -88,100 +97,6 @@
   return placeholder
 
 
-def capture_value(tensor_map, value, dtype, name):
-  """Capture a value from outside the function, to pass in as an extra arg."""
-  captured_value = tensor_map.get(value, None)
-  if captured_value is None:
-    captured_value = create_substitute_placeholder(value, name=name,
-                                                   dtype=dtype)
-    tensor_map[value] = captured_value
-  tape.record_operation("captured_value", [captured_value], [value],
-                        lambda x: [x])
-  return captured_value
-
-
-class CapturingGraph(ops.Graph):
-  """Graph that can capture tensors from other graphs.
-
-  Attributes:
-    captures: Maps external tensor -> internal tensor (e.g. input placeholder).
-      The entries are in the order they were captured.
-  """
-
-  def __init__(self):
-    super(CapturingGraph, self).__init__()
-
-    self.captures = collections.OrderedDict()
-    self._building_function = True
-
-    # Map from resource tensor name to last op (in program order) which uses
-    # this tensor. Used to enforce that execution order matches program order
-    # for resource tensors.
-    self._last_op_using_resource_tensor = {}
-
-  def clear_resource_control_flow_state(self):
-    self._last_op_using_resource_tensor = {}
-
-  # TODO(skyewm): get rid of name and use the name of `tensor`.
-  def capture(self, tensor, name=None):
-    """Capture `tensor` if it's external to this graph.
-
-    If `tensor` is from a different graph, returns a placeholder for it.
-    `tensor` and the placeholder will also appears in self.captures. Multiple
-    calls to this method with the same `tensor` argument will return the same
-    placeholder. If `tensor` is from this graph, returns `tensor`.
-
-    Args:
-      tensor: Tensor. May be from this FuncGraph or a different graph.
-      name: Optional name if a placeholder is created.
-
-    Returns:
-      Tensor from this FuncGraph.
-    """
-    if isinstance(tensor, ops.EagerTensor):
-      if name is None:
-        name = str(ops.uid())
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    if tensor.graph is not self:
-      if name is None:
-        name = tensor.op.name
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    return tensor
-
-  def create_op(
-      self,
-      op_type,
-      inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
-      input_types=None,
-      name=None,
-      attrs=None,
-      op_def=None,
-      compute_shapes=True,
-      compute_device=True):
-    """Captures an external inputs before calling Graph.capture_op."""
-    # This capturing logic interacts poorly with control flow contexts which
-    # want to replace inputs of ops far too late in the process. This can lead
-    # the context to get confused and try to create an Enter for an Enter. We
-    # can detect this here and skip the additional Enter which can confuse loop
-    # validation logic.
-    if op_type == "Enter" and inputs[0].op.type == "Enter":
-      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
-        return inputs[0].op
-    # Calling AddValue on the control flow contexts to force creation of the
-    # backward accumulators in the original graph before we create placeholders
-    # to capture the inputs.
-    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
-    for i, inp in enumerate(inputs):
-      if ctxt is not None and hasattr(ctxt, "AddValue"):
-        inp = ctxt.AddValue(inp)
-      inp = self.capture(inp)
-      inputs[i] = inp
-    return super(CapturingGraph, self).create_op(
-        op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_device=compute_device)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -190,7 +105,45 @@
     return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
 
-class FuncGraph(CapturingGraph):
+def _parse_func_attrs(attributes):
+  """Convert the keyword arguments into function_def attributes.
+
+  Currently only support primitive types: bool, int, float and string.
+
+  Args:
+    attributes: the dictionary of attributes.
+  Returns:
+    A dict of attributes where the key is the name of attribute and the value
+      is the AttrValue proto.
+  Raises:
+    ValueError: If the kwargs contains unwhitelisted name or unsupported value
+      types.
+  """
+  attrs = {}
+  for key, value in attributes.items():
+    if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+      raise ValueError("Attribute name is not whitelisted. "
+                       "Whitelisted: prefix %s, got: %s" %
+                       (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+
+    if isinstance(value, attr_value_pb2.AttrValue):
+      attrs[key] = value
+    # bool type check has to happen before int since bool is a subclass of int.
+    elif isinstance(value, bool):
+      attrs[key] = attr_value_pb2.AttrValue(b=value)
+    elif isinstance(value, int):
+      attrs[key] = attr_value_pb2.AttrValue(i=value)
+    elif isinstance(value, float):
+      attrs[key] = attr_value_pb2.AttrValue(f=value)
+    elif isinstance(value, str):
+      attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+    else:
+      raise ValueError("Unsupported attribute type for %s with type %s" %
+                       (key, type(value)))
+  return attrs
+
+
+class FuncGraph(ops.Graph):
   """Graph representing a function body.
 
   Attributes:
@@ -207,6 +160,8 @@
     variables: Variables that should be watched during function execution.
     outer_graph: The graph this function is defined in. May be another FuncGraph
       or the global default Graph.
+    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
+      The entries are in the order they were captured.
     seed: The graph-level random seed.
   """
 
@@ -227,6 +182,13 @@
     self.structured_outputs = None
     self.variables = []
     self.outer_graph = ops.get_default_graph()
+    self.captures = collections.OrderedDict()
+
+    self._building_function = True
+    # Map from resource tensor name to last op (in program order) which uses
+    # this tensor. Used to enforce that execution order matches program order
+    # for resource tensors.
+    self._last_op_using_resource_tensor = {}
 
     graph = self.outer_graph
 
@@ -255,15 +217,107 @@
     self._graph_key = graph._graph_key
     # pylint: enable=protected-access
 
+  def create_op(
+      self,
+      op_type,
+      inputs,
+      dtypes,
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_shapes=True,
+      compute_device=True):
+    """Like Graph.create_op, except handles external input tensors.
+
+    This overload adds functionality to create_op to "capture" any external
+    input tensors, i.e. tensors from the eager context or outer function graphs
+    if this is a nested function. See `capture` for more information.
+
+    Args:
+      op_type: The `Operation` type to create. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+      dtypes: A list of `DType` objects that will be the types of the tensors
+        that the operation produces.
+      input_types: (Optional.) A list of `DType`s that will be the types of
+        the tensors that the operation consumes. By default, uses the base
+        `DType` of each input in `inputs`. Operations that expect
+        reference-typed inputs must specify `input_types` explicitly.
+      name: (Optional.) A string name for the operation. If not specified, a
+        name is generated based on `op_type`.
+      attrs: (Optional.) A dictionary where the key is the attribute name (a
+        string) and the value is the respective `attr` attribute of the
+        `NodeDef` proto that will represent the operation (an `AttrValue`
+        proto).
+      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+        the operation will have.
+      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+        computed).
+      compute_device: (Optional.) If True, device functions will be executed
+        to compute the device property of the Operation.
+
+    Returns:
+      An `Operation` object.
+    """
+    # This capturing logic interacts poorly with control flow contexts which
+    # want to replace inputs of ops far too late in the process. This can lead
+    # the context to get confused and try to create an Enter for an Enter. We
+    # can detect this here and skip the additional Enter which can confuse loop
+    # validation logic.
+    if op_type == "Enter" and inputs[0].op.type == "Enter":
+      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+        return inputs[0].op
+    # Calling AddValue on the control flow contexts to force creation of the
+    # backward accumulators in the original graph before we create placeholders
+    # to capture the inputs.
+    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
+    for i, inp in enumerate(inputs):
+      # TPU Estimator defines a control flow context with no AddValue method.
+      if ctxt is not None and hasattr(ctxt, "AddValue"):
+        inp = ctxt.AddValue(inp)
+      inp = self.capture(inp)
+      inputs[i] = inp
+    return super(FuncGraph, self).create_op(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device=compute_device)
+
   def capture(self, tensor, name=None):
-    """Calls CapturingGraph.capture and updates self.inputs if necessary."""
-    new_capture = tensor not in self.captures
-    internal_tensor = super(FuncGraph, self).capture(tensor, name)
+    """Captures `tensor` if it's external to this graph.
 
-    if new_capture and tensor is not internal_tensor:
-      self.inputs.append(internal_tensor)
+    If `tensor` is from a different graph, returns a placeholder for it.
+    `tensor` and the placeholder will appear in self.captures, and the
+    placeholder will appear in self.inputs.  Multiple calls to this method with
+    the same `tensor` argument will return the same placeholder. If `tensor` is
+    from this graph, returns `tensor`.
 
-    return internal_tensor
+    Args:
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
+
+    Returns:
+      Tensor from this FuncGraph.
+    """
+    if isinstance(tensor, ops.EagerTensor):
+      if name is None:
+        name = str(ops.uid())
+      return self._capture_helper(tensor, name)
+    if tensor.graph is not self:
+      if name is None:
+        name = tensor.op.name
+      return self._capture_helper(tensor, name)
+    return tensor
+
+  def _capture_helper(self, tensor, name):
+    captured_tensor = self.captures.get(tensor, None)
+    if captured_tensor is None:
+      captured_tensor = _create_substitute_placeholder(tensor, name=name,
+                                                       dtype=tensor.dtype)
+      self.captures[tensor] = captured_tensor
+      self.inputs.append(captured_tensor)
+    tape.record_operation("captured_value", [captured_tensor], [tensor],
+                          lambda x: [x])
+    return captured_tensor
 
   @property
   def external_captures(self):
@@ -475,7 +529,7 @@
     self._num_outputs = len(self._func_graph.outputs)
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
-    self._attrs = attrs or {}
+    self._attrs = _parse_func_attrs(attrs)
     self._device_functions = tuple(
         self._func_graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
@@ -509,7 +563,7 @@
 
     for v in self._func_graph.variables:
       if v.trainable:
-        tape.watch_variable(v)
+        tape.variable_accessed(v)
 
     captures = self._resolve_captured_inputs()
     tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
@@ -869,9 +923,6 @@
           _TensorType(arg.values.dtype, arg.values._shape_tuple()),
           _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
       ])
-  elif isinstance(arg, np.ndarray):
-    tensor = ops.convert_to_tensor(arg)
-    return _TensorType(tensor.dtype, tensor._shape_tuple())
   # pylint: enable=protected-access
   elif isinstance(arg, (list, tuple)):
     return tuple([_encode_arg(elem) for elem in arg])
@@ -901,7 +952,8 @@
   def __init__(self,
                python_function,
                name,
-               input_signature=None):
+               input_signature=None,
+               attributes=None):
     """Initializes a polymorphic function.
 
     Args:
@@ -910,6 +962,8 @@
       input_signature: a possibly nested sequence of `TensorSpec` objects
         specifying the input signature of this function. If `None`, a separate
         function is instantiated for each inferred input signature.
+      attributes: dict, extra keyword arguments that will be added as
+        attributes of the function.
 
     Raises:
       ValueError: if `input_signature` is not None and the `python_function`'s
@@ -927,6 +981,7 @@
     self._name = name
     self._function_cache = collections.OrderedDict()
     self._variables = []
+    self._function_attributes = attributes or {}
 
     self._lock = threading.Lock()
 
@@ -1079,6 +1134,17 @@
       # opposed to named arguments called in a keyword-like fashion.
       kwds.pop(arg)
     inputs = args + _deterministic_dict_values(arg_indices_to_values)
+    flat_inputs = nest.flatten(inputs)
+
+    # Check for NumPy arrays in arguments and convert them to Tensors.
+    need_packing = False
+    for index, value in enumerate(flat_inputs):
+      if isinstance(value, np.ndarray):
+        flat_inputs[index] = constant_op.constant(value)
+        need_packing = True
+    if need_packing:
+      inputs = nest.pack_sequence_as(structure=inputs,
+                                     flat_sequence=flat_inputs)
     if self._input_signature is None:
       return inputs, kwds
     else:
@@ -1088,7 +1154,6 @@
       except (ValueError, TypeError):
         raise ValueError("Structure of Python function inputs does not match "
                          "input_signature.")
-      flat_inputs = nest.flatten(inputs)
       if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
         raise ValueError("When input_signature is provided, all inputs to "
                          "the Python function must be Tensors.")
@@ -1131,13 +1196,42 @@
       if graph_function is None:
         graph_function = Function(
             func_graph_from_py_func(self._name, self._python_function, args,
-                                    kwds, self._input_signature))
+                                    kwds, self._input_signature),
+            self._function_attributes)
         self._variables.extend(
             [v for v in graph_function.variables if v not in self._variables])
         self._function_cache[cache_key] = graph_function
       return graph_function, (args, kwds)
 
 
+def register(func, *args, **kwargs):
+  """Register the defun function into the graph.
+
+  This won't actually call the function with the inputs; it only puts the
+  function definition into the graph. Registering the function with different
+  input parameters will result in multiple versions being registered in graph.
+
+  Args:
+    func: the PolymorphicFunction instance that is generated by a @defun
+    *args: input arguments for the Python function.
+    **kwargs: input keyword arguments for the Python function.
+
+  Returns:
+    a `Function` object specialized to inputs and execution context.
+
+  Raises:
+    ValueError: When the input function is not a defun wrapped python function.
+  """
+  if not isinstance(func, PolymorphicFunction):
+    raise ValueError("Only defun function is allowed to be registered. "
+                     "Got type: %s" % type(func))
+  concrete_func = func.get_concrete_function(*args, **kwargs)
+  graph = ops.get_default_graph()
+  concrete_func._inference_function.add_to_graph(graph)   # pylint: disable=protected-access
+  # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+  return concrete_func
+
+
 def _validate_signature(signature):
   if any(not isinstance(arg, tensor_spec.TensorSpec)
          for arg in nest.flatten(signature)):
@@ -1261,6 +1355,11 @@
   tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
   input signature inferred from `(*args, **kwargs)` and cached for future reuse.
 
+  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
+  before being passed to `f`, and are treated as Tensors for caching. This
+  allows a function to be called multiple times with NumPy arrays having
+  different values but the same shape and dtype without re-tracing each time.
+
   `tf.contrib.eager.defun` caches graphs for your convenience, letting you
   define TensorFlow functions without explicitly specifying their signatures.
   However, this policy is conservative and potentially expensive; for example,
@@ -1460,7 +1559,29 @@
     TypeError: If `input_signature` is neither `None` nor a sequence of
       `tf.contrib.eager.TensorSpec` objects.
   """
+  return defun_with_attributes(func=func, input_signature=input_signature)
 
+
+def defun_with_attributes(func=None, input_signature=None, attributes=None):
+  """Compiles a Python function into a callable TensorFlow graph.
+
+  This function supports adding extra function attributes. See detailed
+  documentation in defun(). Currently this is not exposed in public API since we
+  don't expect user to directly use attributes, and attribute won't work by
+  itself. This assumption might change in future.
+
+  Args:
+    func: function to be compiled.
+    input_signature: same as defun()'s input_signature.
+    attributes: A dictionary of arguments which will be added to function def as
+      attributes. Currently only primitive types are supported as values, and
+      only whitelisted attribute names are allowed. An unwhitelisted attribute
+      name or unsupported value will result in a ValueError.
+
+  Returns:
+    Same as the return value of defun, with attributes added to the function in
+    graph.
+  """
   if input_signature is not None:
     _validate_signature(input_signature)
 
@@ -1472,7 +1593,8 @@
       name = "function"
     return tf_decorator.make_decorator(
         function,
-        PolymorphicFunction(function, name, input_signature=input_signature))
+        PolymorphicFunction(function, name, input_signature=input_signature,
+                            attributes=attributes))
 
   # This code path is for the `foo = tfe.defun(foo, ...)` use case
   if func is not None:
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 3c79099..a0abefe 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -22,12 +22,14 @@
 from multiprocessing.pool import ThreadPool
 import sys
 
+import numpy
+
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -37,6 +39,7 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -56,6 +59,21 @@
 from tensorflow.python.util import nest
 
 
+class MiniModel(keras_training.Model):
+  """Minimal model for mnist.
+
+  Useful for testing and debugging on slow TPU simulators.
+  """
+
+  def __init__(self):
+    super(MiniModel, self).__init__(name='')
+    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+                                 bias_initializer='ones')
+
+  def call(self, inputs, training=True):
+    return self.fc(inputs)
+
+
 @test_util.with_c_shapes
 class FunctionTest(test.TestCase):
 
@@ -105,7 +123,7 @@
     self.assertAllEqual(step(), 2.0)
 
   def testGraphGradientVariable(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
 
       @function.defun
@@ -212,7 +230,7 @@
     self.assertAllEqual(f(), x)
 
   def testSymGradGatherNd(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
 
       @function.defun
       def f(x):
@@ -315,6 +333,7 @@
   def testDefunNumpyArraysConvertedToTensors(self):
 
     def f(x):
+      self.assertIsInstance(x, ops.Tensor)
       return x
 
     x = random_ops.random_uniform([2, 2]).numpy()
@@ -328,6 +347,12 @@
     # shouldn't trigger another function definition.
     self.assertEqual(len(defined._function_cache), 1)
 
+    # Test that the numpy array is properly an argument to the graph function.
+    self.assertEqual(1., defined(numpy.ones([])).numpy())
+    self.assertEqual(0., defined(numpy.zeros([])).numpy())
+    self.assertEqual(1., defined(array_ops.ones([])).numpy())
+    self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+
   def testDefunCapturedInt32(self):
     x = constant_op.constant(1, dtype=dtypes.int32)
 
@@ -482,7 +507,7 @@
     self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
 
   def testGraphModeCaptureVariable(self):
-    with context.graph_mode(), self.test_session() as sess:
+    with context.graph_mode(), self.cached_session() as sess:
 
       class HasAVar(object):
 
@@ -510,12 +535,12 @@
       x = constant_op.constant(1.0)
       l = f(x, v)
       _, dv = gradients_impl.gradients(l, [x, v])
-      with self.test_session():
+      with self.cached_session():
         v.initializer.run()
         self.assertAllEqual(dv.eval(), 0.0)
 
   def testGraphModeManyFunctions(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
 
       @function.defun
       def f(x):
@@ -616,7 +641,6 @@
 
     @function.defun
     def g(x):
-      tape.watch_variable(x)
       y = math_ops.add(x, three)
       f(y)
 
@@ -630,7 +654,6 @@
       return math_ops.add(x, three)
 
     def g(x):
-      tape.watch_variable(three)
       return f(x)
 
     g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -937,7 +960,7 @@
     self.assertEqual(1, int(read()))
 
   def testReturnCapturedGraphTensor(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       t = constant_op.constant(1)
 
       @function.defun
@@ -999,6 +1022,7 @@
       with ops.get_default_graph().as_default():
         create_variable()
 
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testLayerInDefun(self):
     conv = convolutional.Conv2D(
         filters=1,
@@ -1012,7 +1036,34 @@
 
     x = array_ops.ones([1, 2, 2, 1])
     y = model(x)
-    self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+    # Remove reference cycles in model
+    test_util.dismantle_polymorphic_function(model)
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testDefunKerasModelCall(self):
+    model = MiniModel()
+    model.call = function.defun(model.call)
+
+    x = array_ops.ones([1, 2])
+    y = model(x)
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[3.0]], self.evaluate(y))
+
+    # Remove reference cycles in defun.
+    test_util.dismantle_polymorphic_function(model.call)
+    # Break the reference cycle between the MiniModel and the defun:
+    # MiniModel --(through its `call` method)--> PolymorphicFunction
+    # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+    del model.call
 
   # Note: The ConfigProto below unfortunately only configures graph
   # construction. Eager's configuration is controlled in `__main__`.
@@ -1427,14 +1478,14 @@
     grad_t, = backprop.gradients_function(sq, [0])(t)
     self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
 
-    with backprop.GradientTape(persistent=True) as gtape:
-      gtape.watch(t)
+    with backprop.GradientTape(persistent=True) as tape:
+      tape.watch(t)
       one = matmul(t, b=t, transpose_a=True)
       two = matmul(b=t, a=t, transpose_a=True)
       three = matmul(a=t, b=t, transpose_a=True)
 
     for output in [one, two, three]:
-      self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+      self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])
 
   def testGradientInFunctionWithKeywordArguments(self):
 
@@ -1495,12 +1546,151 @@
     side_effecting_function.python_function()
     self.assertAllEqual(state, [0, 0])
 
+  def testFunctionWithExtraAttributes(self):
+    @function.defun_with_attributes(attributes={'experimental_1': 'value1',
+                                                'experimental_2': 2})
+    def matmul(x, y):
+      return math_ops.matmul(x, y)
+
+    def add(x, y):
+      return math_ops.add(x, y)
+    defun_add = function.defun_with_attributes(
+        add, attributes={'experimental_3': True, 'experimental_4': 1.0})
+
+    with context.graph_mode(), self.test_session():
+      with ops.get_default_graph().as_default():
+        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+        sq = matmul(t, t)
+        double = defun_add(t, t)
+        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+
+        graph = ops.get_default_graph()
+        # pylint: disable=protected-access
+        self.assertEqual(len(graph._functions), 2)
+        functions = list(graph._functions.values())
+        self.assertRegexpMatches(
+            functions[0].definition.signature.name, '.*matmul.*')
+        attrs = functions[0].definition.attr
+        self.assertEqual(len(attrs), 2)
+        self.assertEqual(attrs['experimental_1'].s, b'value1')
+        self.assertEqual(attrs['experimental_2'].i, 2)
+
+        self.assertRegexpMatches(
+            functions[1].definition.signature.name, '.*add.*')
+        attrs = functions[1].definition.attr
+        self.assertEqual(len(attrs), 2)
+        self.assertEqual(attrs['experimental_3'].b, True)
+        self.assertEqual(attrs['experimental_4'].f, 1.0)
+        # pylint: enable=protected-access
+
+  def testFunctionWithInvalidAttribute(self):
+    @function.defun_with_attributes(attributes={'attr1': 'value1'})
+    def matmul(x, y):
+      return math_ops.matmul(x, y)
+
+    with self.assertRaisesRegexp(ValueError,
+                                 '.*Attribute name is not whitelisted.*'):
+      with context.graph_mode(), self.test_session():
+        with ops.get_default_graph().as_default():
+          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+          matmul(t, t)
+
+    @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
+    def add(x, y):
+      return math_ops.add(x, y)
+
+    with self.assertRaisesRegexp(ValueError,
+                                 '.*Unsupported attribute type.*'):
+      with context.graph_mode(), self.test_session():
+        with ops.get_default_graph().as_default():
+          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+          add(t, t)
+
+  def testRegisterFunction(self):
+    @function.defun
+    def add(x, y):
+      return math_ops.add(x, y)
+
+    def matmul(x, y):
+      return math_ops.matmul(x, y)
+    defun_matmul = function.defun(matmul)
+
+    with context.graph_mode(), self.cached_session():
+      with ops.get_default_graph().as_default():
+        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+        function.register(defun_matmul, t, t)
+        function.register(add, t, t)
+
+        graph = ops.get_default_graph()
+        # pylint: disable=protected-access
+        self.assertEqual(len(graph._functions), 2)
+        functions = list(graph._functions.values())
+        pre_register_matmul_func_name = functions[0].definition.signature.name
+        self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
+        pre_register_add_func_name = functions[1].definition.signature.name
+        self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+
+        sq = defun_matmul(t, t)
+        double = add(t, t)
+        self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+        self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+        # Make sure the pre registered function is used, and no other function
+        # is added.
+        self.assertEqual(len(graph._functions), 2)
+        functions = list(graph._functions.values())
+        called_func_name = functions[0].definition.signature.name
+        self.assertEqual(pre_register_matmul_func_name, called_func_name)
+        called_func_name = functions[1].definition.signature.name
+        self.assertEqual(pre_register_add_func_name, called_func_name)
+
+  def testRegisterFunctionWithInputSignature(self):
+    def matmul(x, y):
+      return math_ops.matmul(x, y)
+    defun_matmul = function.defun(
+        matmul,
+        input_signature=[
+            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+            tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+        ])
+    with context.graph_mode(), self.cached_session():
+      with ops.get_default_graph().as_default():
+        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+        function.register(defun_matmul, t, t)
+
+        graph = ops.get_default_graph()
+        # pylint: disable=protected-access
+        self.assertEqual(len(graph._functions), 1)
+
+        # Test input param shape mismatch
+        t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        with self.assertRaisesRegexp(
+            ValueError, 'Python inputs incompatible with input_signature'):
+          function.register(defun_matmul, t2, t2)
+
+  def testRegisterFunctionWithCache(self):
+    def matmul(x, y):
+      return math_ops.matmul(x, y)
+    defun_matmul = function.defun(matmul)
+
+    with context.graph_mode(), self.cached_session():
+      with ops.get_default_graph().as_default():
+        t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+        t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
+        function.register(defun_matmul, t, t)
+        function.register(defun_matmul, t2, t2)
+
+        graph = ops.get_default_graph()
+        # Only one function is registered since the input param are in same type
+        # pylint: disable=protected-access
+        self.assertEqual(len(graph._functions), 1)
+
 
 @test_util.with_c_shapes
 class AutomaticControlDependenciesTest(test.TestCase):
 
   def testBasic(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       with function.AutomaticControlDependencies() as c:
@@ -1511,7 +1701,7 @@
       self.assertAllEqual(val.eval(), 4.0)
 
   def testCondMustRun(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1532,7 +1722,7 @@
       self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0)
 
   def testCondMustRunSeparateRead(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1555,7 +1745,7 @@
       self.assertAllEqual(v.read_value().eval(), 6.0)
 
   def testCondNested(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1589,7 +1779,7 @@
       self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0)
 
   def testCondOneBranch(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1609,7 +1799,7 @@
       self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0)
 
   def testCondOneBranchUpdateBefore(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1630,7 +1820,7 @@
       self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0)
 
   def testCondOneBranchUpdateAfter(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
       p = array_ops.placeholder(dtype=dtypes.bool)
@@ -1666,7 +1856,7 @@
     self.assertAllEqual(out, [3, 4, 5])
 
   def testDecorator(self):
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
 
diff --git a/tensorflow/python/eager/graph_only_ops_test.py b/tensorflow/python/eager/graph_only_ops_test.py
index d2a2b4e..3cf3a61 100644
--- a/tensorflow/python/eager/graph_only_ops_test.py
+++ b/tensorflow/python/eager/graph_only_ops_test.py
@@ -32,13 +32,13 @@
   def testGraphZerosLike(self):
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
     z_tf = graph_only_ops.graph_zeros_like(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(np.zeros((2, 3)), z_tf.eval())
 
   def testGraphPlaceholder(self):
     x_tf = graph_only_ops.graph_placeholder(dtypes.int32, shape=(1,))
     y_tf = math_ops.square(x_tf)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = np.array([42])
       y = sess.run(y_tf, feed_dict={x_tf: np.array([42])})
       self.assertAllClose(np.square(x), y)
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 0001528..5f027d1 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -24,12 +24,10 @@
 
 
 VSpace = collections.namedtuple(
-    "VSpace",
-    ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
+    "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
 
 
 def imperative_grad(
-    vspace,
     tape,
     target,
     sources,
@@ -41,7 +39,6 @@
   gradients for all sources.
 
   Args:
-   vspace: the vector space in which to differentiate.
    tape: the gradient tape which stores the trace.
    target: either a Tensor or list of Tensors to be differentiated.
    sources: list of Tensors for which we want gradients
@@ -60,4 +57,7 @@
      computation of target.
   """
   return pywrap_tensorflow.TFE_Py_TapeGradient(
-      tape._tape, vspace, target, sources, output_gradients)  # pylint: disable=protected-access
+      tape._tape,  # pylint: disable=protected-access
+      target,
+      sources,
+      output_gradients)
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 86fbd24..f34ce6a 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,6 +27,8 @@
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"
 
+#include "structmember.h"  // NOLINT // For PyMemberDef
+
 // forward declare
 struct EagerTensor;
 
@@ -325,12 +327,36 @@
   PyObject* context = nullptr;
   PyObject* device = nullptr;
   PyObject* dtype = Py_None;
-  const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
-  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
+  PyObject* other_value = nullptr;
+  const char* kwlist[] = {"value", "context",     "device",
+                          "dtype", "other_value", nullptr};
+  if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
                                    const_cast<char**>(kwlist), &value, &context,
-                                   &device, &dtype)) {
+                                   &device, &dtype, &other_value)) {
     return -1;
   }
+
+  if (other_value != nullptr) {
+    if (!EagerTensor_CheckExact(other_value)) {
+      PyErr_SetString(PyExc_TypeError,
+                      tensorflow::strings::StrCat(
+                          "Expecting an EagerTensor for other_value, got ",
+                          Py_TYPE(other_value)->tp_name)
+                          .c_str());
+
+      return -1;
+    }
+    EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
+    self->handle =
+        TFE_TensorHandleCopySharingTensor(other->handle, self->status);
+
+    if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+      return -1;
+    }
+
+    return 0;
+  }
+
   // Extract dtype
   int desired_dtype = -1;
   if (dtype != Py_None) {
@@ -619,6 +645,15 @@
     {nullptr} /* Sentinel */
 };
 
+#if PY_MAJOR_VERSION < 3
+// Only used for Python2 since Python3 seems to set the __dict__ correctly.
+static PyMemberDef EagerTensor_members[] = {
+    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+     READONLY},
+    {nullptr},
+};
+#endif
+
 static PyMethodDef EagerTensor_methods[] = {
     {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
      PyDoc_STR("_numpy")},
@@ -693,7 +728,7 @@
     nullptr,                            /* tp_iter */
     nullptr,                            /* tp_iternext */
     EagerTensor_methods,                /* tp_methods */
-    nullptr,                            /* tp_members */
+    EagerTensor_members,                /* tp_members */
     EagerTensor_getseters,              /* tp_getset */
     nullptr,                            /* tp_base */
     nullptr,                            /* tp_dict */
@@ -829,7 +864,7 @@
   }
   EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
 #else
-  _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+  _EagerTensorType.tp_base = base_class_type;
 
   if (PyType_Ready(&_EagerTensorType) < 0) {
     if (PyErr_Occurred()) return nullptr;
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 16f8c3c..f1b4042 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@
 // This function is not thread-safe.
 PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
 
+// Registers e as the VSpace to use.
+// `vspace` must be an imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
 // Registers e as the Exception to be raised when the conditions of
 // TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
 // is a signal to the calling code that it should fall back to the safer (and
@@ -124,9 +128,10 @@
 // To unset the profiler, pass Py_None as the value of `profiler`.
 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
 
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+                            PyObject* watch_accessed_variables);
 
 // Removes the passed tape from the set of active tapes.
 void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -158,18 +163,20 @@
                                    PyObject* input_tensor_ids,
                                    PyObject* backward_function);
 
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
 // Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
 
 // Computes a gradient based on information recorded on the tape.`tape` must
-// have been produced by TFE_Py_NewTape. `vspace` must be a
-// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
 // lists of Tensor objects. `output_gradients` is either None or a python list
 // of either Tensor or None, and if not None should have the same length as
 // target.
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
-                              PyObject* target, PyObject* sources,
-                              PyObject* output_gradients, TF_Status* status);
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+                              PyObject* sources, PyObject* output_gradients,
+                              TF_Status* status);
 
 // Execute a tensorflow operation assuming that all provided inputs are
 // correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 0a33a04..9f2f4e0 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -892,9 +892,10 @@
 class GradientTape
     : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
  public:
-  explicit GradientTape(bool persistent)
+  explicit GradientTape(bool persistent, bool watch_accessed_variables)
       : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
-            persistent) {}
+            persistent),
+        watch_accessed_variables_(watch_accessed_variables) {}
 
   virtual ~GradientTape() {
     for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@
     }
   }
 
+  void VariableAccessed(PyObject* v) {
+    if (watch_accessed_variables_) {
+      WatchVariable(v);
+    }
+  }
+
   void WatchVariable(PyObject* v) {
     tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
     if (handle == nullptr) {
@@ -951,6 +958,7 @@
     }
   };
 
+  bool watch_accessed_variables_;
   tensorflow::mutex watched_variables_mu_;
   std::set<IdAndVariable, CompareById> watched_variables_
       GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@
 
 void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
 
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+                            PyObject* watch_accessed_variables) {
   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
-  tape->tape = new GradientTape(persistent == Py_True);
+  tape->tape = new GradientTape(persistent == Py_True,
+                                watch_accessed_variables == Py_True);
   Py_INCREF(tape);
   GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
   return reinterpret_cast<PyObject*>(tape);
@@ -1233,15 +1243,22 @@
   return list;
 }
 
-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
   if (*ThreadTapeIsStopped()) {
     return;
   }
   for (TFE_Py_Tape* tape : SafeTapeSet()) {
-    tape->tape->WatchVariable(variable);
+    tape->tape->VariableAccessed(variable);
   }
 }
 
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+  if (*ThreadTapeIsStopped()) {
+    return;
+  }
+  reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
+}
+
 PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
   return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
 }
@@ -1348,7 +1365,9 @@
 class PyVSpace
     : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
  public:
-  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+    Py_INCREF(py_vspace_);
+  }
 
   tensorflow::Status Initialize() {
     num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
@@ -1376,15 +1395,21 @@
     Py_XDECREF(aggregate_fn_);
     Py_XDECREF(zeros_);
     Py_XDECREF(ones_);
+
+    Py_DECREF(py_vspace_);
   }
 
   tensorflow::int64 NumElements(PyObject* tensor) const final {
     PyObject* arglist =
         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
     PyObject* result = PyEval_CallObject(num_elements_, arglist);
+    Py_DECREF(arglist);
+    if (result == nullptr) {
+      // The caller detects whether a python exception has been raised.
+      return -1;
+    }
     tensorflow::int64 r = MakeInt(result);
     Py_DECREF(result);
-    Py_DECREF(arglist);
     return r;
   }
 
@@ -1491,6 +1516,22 @@
   PyObject* zeros_;
   PyObject* ones_;
 };
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+  if (py_vspace != nullptr) {
+    delete py_vspace;
+  }
+
+  py_vspace = new PyVSpace(e);
+  auto status = py_vspace->Initialize();
+  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+    delete py_vspace;
+    return nullptr;
+  }
+
+  Py_RETURN_NONE;
+}
 
 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -1507,9 +1548,9 @@
   return list;
 }
 
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
-                              PyObject* target, PyObject* sources,
-                              PyObject* output_gradients, TF_Status* status) {
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+                              PyObject* sources, PyObject* output_gradients,
+                              TF_Status* status) {
   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
   if (!tape_obj->tape->IsPersistent()) {
     auto* tape_set = GetTapeSet();
@@ -1524,10 +1565,6 @@
       return nullptr;
     }
   }
-  PyVSpace c_vspace(vspace);
-  if (!c_vspace.Initialize().ok()) {
-    return nullptr;
-  }
 
   std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
   if (PyErr_Occurred()) {
@@ -1551,7 +1588,7 @@
   }
   std::vector<PyObject*> result;
   status->status = tape_obj->tape->ComputeGradient(
-      c_vspace, target_vec, sources_vec, outgrad_vec, &result);
+      *py_vspace, target_vec, sources_vec, outgrad_vec, &result);
   if (!status->status.ok()) {
     if (PyErr_Occurred()) {
       // Do not propagate the erroneous status as that would swallow the
@@ -1707,117 +1744,167 @@
   Py_RETURN_NONE;
 }
 
-bool OpDoesntRequireOutput(const string& op_name) {
-  static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_outputs =
-      new tensorflow::gtl::FlatSet<string>({
-          "Identity",
-          "MatMul",
-          "Conv2DBackpropInput",
-          "Conv2DBackpropFilter",
-          "Conv3D",
-          "Conv3DBackpropInputV2",
-          "AvgPool3D",
-          "AvgPool3DGrad",
-          "MaxPool3D",
-          "MaxPool3DGrad",
-          "MaxPool3DGradGrad",
-          "BiasAdd",
-          "BiasAddV1",
-          "BiasAddGrad",
-          "Softplus",
-          "SoftplusGrad",
-          "Softsign",
-          "ReluGrad",
-          "Conv2D",
-          "DepthwiseConv2dNative",
-          "Dilation2D",
-          "AvgPool",
-          "AvgPoolGrad",
-          "BatchNormWithGlobalNormalization",
-          "L2Loss",
-          "Sum",
-          "Prod",
-          "SegmentSum",
-          "SegmentMean",
-          "SparseSegmentSum",
-          "SparseSegmentMean",
-          "SparseSegmentSqrtN",
-          "SegmentMin",
-          "SegmentMax",
-          "UnsortedSegmentSum",
-          "UnsortedSegmentMax",
-          "Abs",
-          "Neg",
-          "ReciprocalGrad",
-          "Square",
-          "Expm1",
-          "Log",
-          "Log1p",
-          "TanhGrad",
-          "SigmoidGrad",
-          "Sign",
-          "Sin",
-          "Cos",
-          "Tan",
-          "Add",
-          "Sub",
-          "Mul",
-          "Div",
-          "RealDiv",
-          "Maximum",
-          "Minimum",
-          "SquaredDifference",
-          "Select",
-          "SparseMatMul",
-          "BatchMatMul",
-          "Complex",
-          "Real",
-          "Imag",
-          "Angle",
-          "Conj",
-          "Cast",
-          "Cross",
-          "Cumsum",
-          "Cumprod",
-          "ReadVariableOp",
-          "VarHandleOp",
-          "Shape",
-          "StridedSlice",
+// Returns a pair where the first value of the pair indicates whether or not all
+// outputs are unused. If the first value is false, the second value is a
+// set that identifies which of the output indices are unused.
+bool OpGradientDoesntRequireOutputIndices(
+    const string& op_name,
+    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+  static tensorflow::gtl::FlatMap<
+      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+      new tensorflow::gtl::FlatMap<
+          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+          // Ops that don't require any outputs.
+          {"Identity", {true, {}}},
+          {"MatMul", {true, {}}},
+          {"Conv2DBackpropInput", {true, {}}},
+          {"Conv2DBackpropFilter", {true, {}}},
+          {"Conv3D", {true, {}}},
+          {"Conv3DBackpropInputV2", {true, {}}},
+          {"AvgPool3D", {true, {}}},
+          {"AvgPool3DGrad", {true, {}}},
+          {"MaxPool3D", {true, {}}},
+          {"MaxPool3DGrad", {true, {}}},
+          {"MaxPool3DGradGrad", {true, {}}},
+          {"BiasAdd", {true, {}}},
+          {"BiasAddV1", {true, {}}},
+          {"BiasAddGrad", {true, {}}},
+          {"Softplus", {true, {}}},
+          {"SoftplusGrad", {true, {}}},
+          {"Softsign", {true, {}}},
+          {"ReluGrad", {true, {}}},
+          {"Conv2D", {true, {}}},
+          {"DepthwiseConv2dNative", {true, {}}},
+          {"Dilation2D", {true, {}}},
+          {"AvgPool", {true, {}}},
+          {"AvgPoolGrad", {true, {}}},
+          {"BatchNormWithGlobalNormalization", {true, {}}},
+          {"L2Loss", {true, {}}},
+          {"Sum", {true, {}}},
+          {"Prod", {true, {}}},
+          {"SegmentSum", {true, {}}},
+          {"SegmentMean", {true, {}}},
+          {"SparseSegmentSum", {true, {}}},
+          {"SparseSegmentMean", {true, {}}},
+          {"SparseSegmentSqrtN", {true, {}}},
+          {"SegmentMin", {true, {}}},
+          {"SegmentMax", {true, {}}},
+          {"UnsortedSegmentSum", {true, {}}},
+          {"UnsortedSegmentMax", {true, {}}},
+          {"Abs", {true, {}}},
+          {"Neg", {true, {}}},
+          {"ReciprocalGrad", {true, {}}},
+          {"Square", {true, {}}},
+          {"Expm1", {true, {}}},
+          {"Log", {true, {}}},
+          {"Log1p", {true, {}}},
+          {"TanhGrad", {true, {}}},
+          {"SigmoidGrad", {true, {}}},
+          {"Sign", {true, {}}},
+          {"Sin", {true, {}}},
+          {"Cos", {true, {}}},
+          {"Tan", {true, {}}},
+          {"Add", {true, {}}},
+          {"Sub", {true, {}}},
+          {"Mul", {true, {}}},
+          {"Div", {true, {}}},
+          {"RealDiv", {true, {}}},
+          {"Maximum", {true, {}}},
+          {"Minimum", {true, {}}},
+          {"SquaredDifference", {true, {}}},
+          {"Select", {true, {}}},
+          {"SparseMatMul", {true, {}}},
+          {"BatchMatMul", {true, {}}},
+          {"Complex", {true, {}}},
+          {"Real", {true, {}}},
+          {"Imag", {true, {}}},
+          {"Angle", {true, {}}},
+          {"Conj", {true, {}}},
+          {"Cast", {true, {}}},
+          {"Cross", {true, {}}},
+          {"Cumsum", {true, {}}},
+          {"Cumprod", {true, {}}},
+          {"ReadVariableOp", {true, {}}},
+          {"VarHandleOp", {true, {}}},
+          {"Shape", {true, {}}},
+          {"StridedSlice", {true, {}}},
+          {"Fill", {true, {}}},
+
+          // Ops that don't require a subset of outputs.
+          {"FusedBatchNorm", {false, {0, 1, 2}}},
       });
 
-  return ops_that_dont_require_outputs->find(op_name) !=
-         ops_that_dont_require_outputs->end();
+  auto it = m->find(op_name);
+
+  if (it == m->end()) return false;
+
+  *output = &it->second;
+  return true;
 }
 
-bool OpDoesntRequireInput(const string& op_name) {
-  static tensorflow::gtl::FlatSet<string>* ops_that_dont_require_inputs =
-      new tensorflow::gtl::FlatSet<string>({
-          "Identity",
-          "Softmax",
-          "LogSoftmax",
-          "BiasAdd",
-          "Relu",
-          "Relu6",
-          "Elu",
-          "Selu",
-          "SparseSoftmaxCrossEntropyWithLogits",
-          "Neg",
-          "Inv",
-          "Reciprocal",
-          "Sqrt",
-          "Exp",
-          "Tanh",
-          "Sigmoid",
-          "Real",
-          "Imag",
-          "Conj",
-          "ReadVariableOp",
-          "VarHandleOp",
-          "Shape",
+// Returns a pair where the first value of the pair indicates whether or not all
+// inputs are unused. If the first value is false, the second value is a
+// set that identifies which of the input indices are unused.
+bool OpGradientDoesntRequireInputIndices(
+    const string& op_name,
+    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
+  static tensorflow::gtl::FlatMap<
+      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
+      new tensorflow::gtl::FlatMap<
+          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
+          // Ops that don't require any inputs.
+          {"Identity", {true, {}}},
+          {"Softmax", {true, {}}},
+          {"LogSoftmax", {true, {}}},
+          {"BiasAdd", {true, {}}},
+          {"Relu", {true, {}}},
+          {"Relu6", {true, {}}},
+          {"Elu", {true, {}}},
+          {"Selu", {true, {}}},
+          {"SparseSoftmaxCrossEntropyWithLogits", {true, {}}},
+          {"Neg", {true, {}}},
+          {"Inv", {true, {}}},
+          {"Reciprocal", {true, {}}},
+          {"Sqrt", {true, {}}},
+          {"Exp", {true, {}}},
+          {"Tanh", {true, {}}},
+          {"Sigmoid", {true, {}}},
+          {"Real", {true, {}}},
+          {"Imag", {true, {}}},
+          {"Conj", {true, {}}},
+          {"ReadVariableOp", {true, {}}},
+          {"VarHandleOp", {true, {}}},
+          {"Shape", {true, {}}},
+          {"Fill", {true, {}}},
+
+          // Ops that don't require a subset of inputs.
+          {"FusedBatchNorm", {false, {2}}},
       });
 
-  return ops_that_dont_require_inputs->find(op_name) !=
-         ops_that_dont_require_inputs->end();
+  auto it = m->find(op_name);
+
+  if (it == m->end()) return false;
+
+  *output = &it->second;
+  return true;
+}
+
+PyObject* CopySequenceSettingIndicesToNull(
+    PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
+  tensorflow::Safe_PyObjectPtr fast_seq(
+      PySequence_Fast(seq, "unable to allocate"));
+  PyObject* result = PyTuple_New(PySequence_Fast_GET_SIZE(fast_seq.get()));
+  for (int i = 0; i < PySequence_Fast_GET_SIZE(fast_seq.get()); i++) {
+    PyObject* item;
+    if (indices.find(i) != indices.end()) {
+      item = Py_None;
+    } else {
+      item = PySequence_Fast_GET_ITEM(fast_seq.get(), i);
+    }
+    Py_INCREF(item);
+    PyTuple_SET_ITEM(result, i, item);
+  }
+  return result;
 }
 
 PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
@@ -1837,16 +1924,35 @@
   if (!should_record) Py_RETURN_NONE;
 
   string c_op_name = TFE_GetPythonString(op_name);
+
   PyObject* op_outputs;
-  if (OpDoesntRequireOutput(c_op_name)) {
-    op_outputs = Py_None;
+  bool op_outputs_tuple_created = false;
+  std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
+
+  if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
+    if (outputs_not_required->first) {
+      op_outputs = Py_None;
+    } else {
+      op_outputs_tuple_created = true;
+      op_outputs = CopySequenceSettingIndicesToNull(
+          results, outputs_not_required->second);
+    }
   } else {
     op_outputs = results;
   }
 
   PyObject* op_inputs;
-  if (OpDoesntRequireInput(c_op_name)) {
-    op_inputs = Py_None;
+  bool op_inputs_tuple_created = false;
+  std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
+
+  if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
+    if (inputs_not_required->first) {
+      op_inputs = Py_None;
+    } else {
+      op_inputs_tuple_created = true;
+      op_inputs =
+          CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+    }
   } else {
     op_inputs = inputs;
   }
@@ -1889,18 +1995,20 @@
       });
 
   Py_DECREF(num_inputs);
+  if (op_outputs_tuple_created) Py_DECREF(op_outputs);
+  if (op_inputs_tuple_created) Py_DECREF(op_inputs);
 
   Py_RETURN_NONE;
 }
 
-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
   DCHECK(CheckResourceVariable(input));
   DCHECK(PyObject_HasAttrString(input, "_trainable"));
 
   tensorflow::Safe_PyObjectPtr trainable(
       PyObject_GetAttrString(input, "_trainable"));
   if (trainable.get() == Py_False) return;
-  TFE_Py_TapeSetWatchVariable(input);
+  TFE_Py_TapeVariableAccessed(input);
 }
 
 bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1931,7 +2039,7 @@
 bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
                     PyObject* input, tensorflow::Safe_PyObjectPtr* output,
                     TF_Status* status) {
-  MaybeWatchVariable(input);
+  MaybeNotifyVariableAccessed(input);
 
   TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
   auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
@@ -2459,13 +2567,18 @@
   int num_retvals = 0;
   for (int i = 0; i < op_def->output_arg_size(); i++) {
     const auto& output_arg = op_def->output_arg(i);
+    int delta = 1;
     if (!output_arg.number_attr().empty()) {
-      num_retvals += attr_list_sizes[output_arg.number_attr()];
+      delta = attr_list_sizes[output_arg.number_attr()];
     } else if (!output_arg.type_list_attr().empty()) {
-      num_retvals += attr_list_sizes[output_arg.type_list_attr()];
-    } else {
-      num_retvals++;
+      delta = attr_list_sizes[output_arg.type_list_attr()];
     }
+    if (delta < 0) {
+      RaiseFallbackException(
+          "Attributes suggest that the size of an output list is less than 0");
+      return nullptr;
+    }
+    num_retvals += delta;
   }
 
   tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index fd8ab69..669fa08 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,6 +21,7 @@
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import core
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -123,8 +124,8 @@
   def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
     ctx = context.context()
     with backprop.GradientTape(persistent=True) as tape:
-      a_2_by_2 = constant_op.constant(
-          [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+      a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]],
+                                      dtype=dtypes.float32)
       a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
       m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
       m2 = resource_variable_ops._MixedPrecisionVariable(
@@ -233,6 +234,26 @@
       pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
                                                ctx_handle, None, [], a_2_by_2)
 
+  @test_util.assert_no_new_tensors
+  @test_util.assert_no_garbage_created
+  def testFastPathExecute_InvalidAttributes(self):
+    split_dim = constant_op.constant(0, dtype=dtypes.int32)
+    value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
+    ctx = context.context()
+    ctx_handle = ctx._handle
+    with self.assertRaises(core._FallbackException):
+      pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+                                               "Split", None, None, split_dim,
+                                               value, "num_split", -1)
+
+  @test_util.assert_no_new_tensors
+  @test_util.assert_no_garbage_created
+  def testInvalidNumOutputs(self):
+    with self.assertRaisesRegexp(
+        Exception,
+        "Value for attr 'num_split' of -1 must be at least minimum 1"):
+      array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 6eb62af..399d902 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@
     return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
 
 
-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
   """Pushes a new tape onto the tape stack."""
-  tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+  tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+                                             watch_accessed_variables)
   return Tape(tape)
 
 
@@ -49,13 +50,14 @@
   pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access
 
 
-def watch_variable(variable):
-  """Marks this variable to be watched by all tapes in the stack.
+def watch_variable(tape, variable):
+  """Marks this variable to be watched by the given tape."""
+  pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable)  # pylint: disable=protected-access
 
-  Args:
-    variable: variable to be watched.
-  """
-  pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
+
+def variable_accessed(variable):
+  """Notifies all tapes in the stack that a variable has been accessed."""
+  pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
 
 
 def pop_tape(tape):
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index 4326d5e..acd0e56 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -72,7 +72,7 @@
     a = constant_op.constant([[1., 0.], [0., 1.]])
     b = constant_op.constant([[1., 2.], [3., 4.]])
     da, db = backprop.gradients_function(fn, [0, 1])(a, b)
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
       tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
       tf_c = tf_a + tf_b
@@ -135,7 +135,7 @@
     a = constant_op.constant([[1., 0.], [0., 1.]])
     b = constant_op.constant([[1., 2.], [3., 4.]])
     da, db = backprop.gradients_function(fn, [0, 1])(a, b)
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       tf_a = constant_op.constant([[1, 0], [0, 1]], dtype=dtypes.float32)
       tf_b = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
       tf_mm = math_ops.matmul(tf_a, tf_b)
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 871136e..344a9b2 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
 
 
 def _create_tensor(value, device=None, dtype=None):
@@ -295,6 +296,7 @@
   def testFloatTensor(self):
     self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype)
     self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype)
+    self.assertEqual(dtypes.float16, _create_tensor(np.float16()).dtype)
     self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype)
 
   def testSliceDimOutOfRange(self):
@@ -332,6 +334,19 @@
         "but tensor at index 2 has rank 0"):
       pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
 
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testTensorDir(self):
+    t = array_ops.zeros(1)
+    t.test_attr = "Test"
+
+    instance_dir = dir(t)
+    type_dir = dir(ops.EagerTensor)
+
+    # Monkey patched attributes should show up in dir(t)
+    self.assertIn("test_attr", instance_dir)
+    instance_dir.remove("test_attr")
+    self.assertEqual(instance_dir, type_dir)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 9fce172..bfcc019 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -685,7 +685,7 @@
     srcs_version = "PY2AND3",
     tags = [
         "no_windows",
-        "notsan",
+        "notsan",  # b/67510291
     ],
     deps = [
         ":keras",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index d104c96..19f1801 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -1000,8 +1000,11 @@
     bucketized_feature_2 = bucketized_column(
       numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
+    # Need to see a large portion of the data before we can build a layer; for
+    # example, half the data: n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
     classifier = estimator.BoostedTreesClassifier(
         feature_columns=[bucketized_feature_1, bucketized_feature_2],
+        n_batches_per_layer=n_batches_per_layer,
         n_trees=100,
         ... <some other params>
     )
@@ -1024,7 +1027,8 @@
         the model. All items in the set should be instances of classes derived
         from `FeatureColumn`.
       n_batches_per_layer: the number of batches to collect statistics per
-        layer.
+        layer. The total number of batches is the total number of examples
+        divided by the batch size.
       model_dir: Directory to save model parameters, graph and etc. This can
         also be used to load checkpoints from the directory into a estimator
         to continue training a previously saved model.
@@ -1138,8 +1142,11 @@
     bucketized_feature_2 = bucketized_column(
       numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
 
+    # Need to see a large portion of the data before we can build a layer; for
+    # example, half the data: n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
     regressor = estimator.BoostedTreesRegressor(
         feature_columns=[bucketized_feature_1, bucketized_feature_2],
+        n_batches_per_layer=n_batches_per_layer,
         n_trees=100,
         ... <some other params>
     )
@@ -1162,7 +1169,8 @@
         the model. All items in the set should be instances of classes derived
         from `FeatureColumn`.
       n_batches_per_layer: the number of batches to collect statistics per
-        layer.
+        layer. The total number of batches is the total number of examples
+        divided by the batch size.
       model_dir: Directory to save model parameters, graph and etc. This can
         also be used to load checkpoints from the directory into a estimator
         to continue training a previously saved model.
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 08026a9..6e28c72 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -1560,7 +1560,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third = (
         self._get_expected_ensembles_for_classification())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train with train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1593,7 +1593,7 @@
     expected_first, expected_second, expected_third, expected_forth = (
         self._get_expected_ensembles_for_classification_with_bias())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
             boosted_trees._create_classification_head(n_classes=2),
@@ -1633,7 +1633,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third = (
         self._get_expected_ensembles_for_classification())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train without train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1666,7 +1666,7 @@
     expected_first, expected_second, expected_third, expected_forth = (
         self._get_expected_ensembles_for_classification_with_bias())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
             boosted_trees._create_classification_head(n_classes=2),
@@ -1704,7 +1704,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third = (
         self._get_expected_ensembles_for_regression())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train with train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1734,7 +1734,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third, expected_forth = (
         self._get_expected_ensembles_for_regression_with_bias())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train with train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1774,7 +1774,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third = (
         self._get_expected_ensembles_for_regression())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train without train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
@@ -1804,7 +1804,7 @@
     ops.reset_default_graph()
     expected_first, expected_second, expected_third, expected_forth = (
         self._get_expected_ensembles_for_regression_with_bias())
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Train with train_in_memory mode.
       with sess.graph.as_default():
         train_op, ensemble_serialized = self._get_train_op_and_ensemble(
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index c08cf61..1c0c458 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -142,7 +142,7 @@
                   dropout=None,
                   input_layer_partitioner=None,
                   config=None,
-                  tpu_estimator_spec=False,
+                  use_tpu=False,
                   batch_norm=False):
   """Deep Neural Net model_fn.
 
@@ -164,8 +164,8 @@
     input_layer_partitioner: Partitioner for input layer. Defaults
       to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
     config: `RunConfig` object to configure the runtime settings.
-    tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
-      or `model_fn.EstimatorSpec` instance.
+    use_tpu: Whether to make a DNN model able to run on TPU. Will make function
+      return a `_TPUEstimatorSpec` instance and disable variable partitioning.
     batch_norm: Whether to use batch normalization after each hidden layer.
 
   Returns:
@@ -182,13 +182,15 @@
       optimizer, learning_rate=_LEARNING_RATE)
   num_ps_replicas = config.num_ps_replicas if config else 0
 
-  partitioner = partitioned_variables.min_max_variable_partitioner(
-      max_partitions=num_ps_replicas)
+  partitioner = (None if use_tpu else
+                 partitioned_variables.min_max_variable_partitioner(
+                     max_partitions=num_ps_replicas))
   with variable_scope.variable_scope(
       'dnn',
       values=tuple(six.itervalues(features)),
       partitioner=partitioner):
     input_layer_partitioner = input_layer_partitioner or (
+        None if use_tpu else
         partitioned_variables.min_max_variable_partitioner(
             max_partitions=num_ps_replicas,
             min_slice_size=64 << 20))
@@ -203,7 +205,7 @@
         batch_norm=batch_norm)
     logits = logit_fn(features=features, mode=mode)
 
-    if tpu_estimator_spec:
+    if use_tpu:
       return head._create_tpu_estimator_spec(  # pylint: disable=protected-access
           features=features,
           mode=mode,
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index bd2e0ae..de9c84d 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -260,7 +260,7 @@
         features={'x': np.array(((30.,), (42.,),))},
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
         spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
             logits_placeholder: logits_2x2
@@ -293,7 +293,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -347,14 +347,14 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError('Labels must <= n_classes - 1'):
         training_loss.eval({
             labels_placeholder: labels_2x1_with_large_id,
             logits_placeholder: logits_2x3
         })
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError('Labels must >= 0'):
         training_loss.eval({
             labels_placeholder: labels_2x1_with_negative_id,
@@ -413,7 +413,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -449,7 +449,7 @@
         spec.export_outputs.keys())
 
     # Assert predictions and export_outputs.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       predictions = sess.run(spec.predictions)
@@ -484,7 +484,7 @@
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertAllEqual(
           expected_classes,
@@ -510,7 +510,7 @@
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       predictions = sess.run(spec.predictions)
       self.assertAllClose(logits,
@@ -534,7 +534,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -561,7 +561,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_input,
         labels=labels_input)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(np.sum(loss), actual_training_loss.eval())
 
@@ -581,7 +581,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -632,7 +632,7 @@
 
     # Assert predictions, loss, and metrics.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -698,7 +698,7 @@
 
     # Assert predictions, loss, and metrics.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -727,7 +727,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -755,7 +755,7 @@
     }
 
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -804,7 +804,7 @@
 
     # Assert loss, and metrics.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -837,7 +837,7 @@
         logits=logits,
         labels=labels)
     tol = 1e-2
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -866,7 +866,7 @@
         logits=logits,
         labels=labels)
     tol = 1e-2
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -921,7 +921,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -962,7 +962,7 @@
         optimizer=_Optimizer())
 
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -992,7 +992,7 @@
           labels=np.array(((1,), (1,)), dtype=np.int64),
           train_op_fn=_train_op_fn)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         _initialize_variables(self, spec.scaffold)
         sess.run(spec.train_op)
         w_value, t_value = sess.run([w, t])
@@ -1023,7 +1023,7 @@
 
     # Assert summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       summary_str = sess.run(spec.scaffold.summary_op)
@@ -1064,7 +1064,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1104,7 +1104,7 @@
         logits=logits,
         labels=labels_rank_1)
     tol = 1e-2
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1153,7 +1153,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1183,7 +1183,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1211,7 +1211,7 @@
         train_op_fn=_train_op_fn)
 
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss = sess.run(spec.loss)
       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1253,7 +1253,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1292,7 +1292,7 @@
         logits=logits,
         labels=labels)
     tol = 1e-2
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -1327,7 +1327,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1353,7 +1353,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -1380,7 +1380,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -1413,7 +1413,7 @@
 
     # Assert predictions, loss, and metrics.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -1506,7 +1506,7 @@
         features={'x': np.array(((42.,),))},
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
         spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval({
             logits_placeholder: logits_2x2
@@ -1536,7 +1536,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[2 2\]'):
@@ -1577,7 +1577,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[3 1\] \[labels_shape: \] \[2 1\]'):
@@ -1585,7 +1585,7 @@
             labels_placeholder: values_2x1,
             logits_placeholder: values_3x1
         })
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 1\] \[labels_shape: \] \[3 1\]'):
@@ -1624,7 +1624,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       predictions = sess.run(spec.predictions)
@@ -1660,7 +1660,7 @@
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertAllEqual(
           expected_classes,
@@ -1680,7 +1680,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1733,7 +1733,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1808,7 +1808,7 @@
     }
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1832,7 +1832,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(41., training_loss.eval())
 
@@ -1849,7 +1849,7 @@
         logits=logits,
         labels=labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1877,7 +1877,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -1924,7 +1924,7 @@
     }
     self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -1957,7 +1957,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -1983,7 +1983,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -2011,7 +2011,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_input,
         labels=labels_input)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(np.sum(loss), actual_training_loss.eval())
 
@@ -2031,7 +2031,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -2086,7 +2086,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2126,7 +2126,7 @@
         labels=labels,
         optimizer=_Optimizer())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAllClose(expected_loss, loss)
@@ -2153,7 +2153,7 @@
           labels=np.array(((1,), (1,),), dtype=np.float64),
           train_op_fn=_train_op_fn)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         _initialize_variables(self, spec.scaffold)
         sess.run(spec.train_op)
         w_value, t_value = sess.run([w, t])
@@ -2182,7 +2182,7 @@
         labels=labels,
         train_op_fn=_train_op_fn)
     # Assert summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       summary_str = sess.run(spec.scaffold.summary_op)
@@ -2227,7 +2227,7 @@
         regularization_losses=regularization_losses)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -2254,7 +2254,7 @@
     with self.assertRaisesRegexp(
         errors.InvalidArgumentError,
         r'Labels must <= n_classes - 1'):
-      with self.test_session():
+      with self.cached_session():
         _initialize_variables(self, monitored_session.Scaffold())
         training_loss.eval()
 
@@ -2277,7 +2277,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2309,7 +2309,7 @@
         train_op_fn=_train_op_fn)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)
@@ -2334,7 +2334,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
@@ -2360,7 +2360,7 @@
     expected_loss = 1.2484322
 
     # Assert loss.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2385,7 +2385,7 @@
         logits=logits)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       predictions = sess.run(spec.predictions)
       self.assertAllClose(
@@ -2447,7 +2447,7 @@
     self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2483,7 +2483,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels_rank_1)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(),
@@ -2531,7 +2531,7 @@
     self.assertIsNotNone(spec.train_op)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((
@@ -2577,7 +2577,7 @@
     self.assertIsNotNone(spec.train_op)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       loss, train_result, summary_str = sess.run((
@@ -2612,7 +2612,7 @@
         logits=logits,
         labels=labels)
     tol = 1e-2
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(
           expected_training_loss, training_loss.eval(),
@@ -2649,7 +2649,7 @@
 
     # Assert predictions, loss, train_op, and summaries.
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -2675,7 +2675,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -2700,7 +2700,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -2744,7 +2744,7 @@
     }
 
     tol = 1e-2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
@@ -2825,7 +2825,7 @@
         features={'x': np.array(((42.,),))},
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
         spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval({
             logits_placeholder: logits_1d
@@ -2857,7 +2857,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
         spec.loss.eval({
             labels_placeholder: values_3d,
@@ -2868,7 +2868,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2908,7 +2908,7 @@
         logits=logits_placeholder,
         labels=labels_placeholder,
         train_op_fn=lambda x: x)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.OpError, 'logits shape'):
         spec.loss.eval({
             labels_placeholder: values_3d,
@@ -2919,7 +2919,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits_placeholder,
         labels=labels_placeholder)[0]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           r'\[expected_labels_shape: \] \[2 3\] \[labels_shape: \] \[2 1\]'):
@@ -2957,7 +2957,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions.
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, spec.scaffold)
       self.assertAllClose(logits, spec.predictions[prediction_key].eval())
       self.assertAllClose(
@@ -2992,7 +2992,7 @@
         spec.export_outputs.keys())
 
     # Assert predictions.
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, spec.scaffold)
       self.assertAllClose(
           expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
@@ -3019,7 +3019,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       # loss = [(43-45)^2, (44-41)] = [4, 9]
       self.assertAllClose(13., training_loss.eval())
@@ -3045,7 +3045,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits_input,
         labels=labels_input)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(np.sum(loss), actual_training_loss.eval())
 
@@ -3064,7 +3064,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -3112,7 +3112,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3180,7 +3180,7 @@
     }
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -3212,7 +3212,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3237,7 +3237,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3294,7 +3294,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       predictions, loss, train_result, summary_str = sess.run((
@@ -3337,7 +3337,7 @@
         labels=labels,
         optimizer=_Optimizer())
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss, train_result = sess.run((spec.loss, spec.train_op))
       self.assertAllClose(expected_loss, loss)
@@ -3364,7 +3364,7 @@
           labels=np.array(((43.,), (44.,),), dtype=np.float64),
           train_op_fn=_train_op_fn)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         _initialize_variables(self, spec.scaffold)
         sess.run(spec.train_op)
         w_value, t_value = sess.run([w, t])
@@ -3394,7 +3394,7 @@
         train_op_fn=_train_op_fn)
 
     # Assert summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       summary_str = sess.run(spec.scaffold.summary_op)
@@ -3441,7 +3441,7 @@
         regularization_losses=regularization_losses)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       prediction_key = prediction_keys.PredictionKeys.PREDICTIONS
@@ -3487,7 +3487,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3523,7 +3523,7 @@
         labels=np.array(((35,), (42,), (45,)), dtype=np.int32))
 
     # Assert loss.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       loss = sess.run(spec.loss)
       # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
@@ -3565,7 +3565,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       predictions, loss, train_result, summary_str = sess.run((
@@ -3600,7 +3600,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels_rank_1)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3648,7 +3648,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       predictions, loss, train_result, summary_str = sess.run((
@@ -3679,7 +3679,7 @@
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
       # weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3718,7 +3718,7 @@
     _assert_no_hooks(self, spec)
 
     # Assert predictions, loss, and metrics.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
       loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[
@@ -3750,7 +3750,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)[0]
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].
       # weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6
@@ -3796,7 +3796,7 @@
     _assert_no_hooks(self, spec)
 
     # Evaluate predictions, loss, train_op, and summaries.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
       predictions, loss, train_result, summary_str = sess.run((
@@ -3857,7 +3857,7 @@
     self.assertIsNone(spec.train_op)
     _assert_no_hooks(self, spec)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Finalize graph and initialize variables.
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3915,7 +3915,7 @@
     self.assertEqual(dtypes.float32, spec.loss.dtype)
     self.assertIsNotNone(spec.train_op)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Finalize graph and initialize variables.
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
@@ -3955,7 +3955,7 @@
         mode=model_fn.ModeKeys.TRAIN,
         logits=logits,
         labels=labels)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_training_loss, training_loss.eval())
       self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
@@ -3988,7 +3988,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_train_op_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       self.assertAllClose(expected_loss, spec.loss.eval())
 
@@ -4013,7 +4013,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
@@ -4042,7 +4042,7 @@
         logits=logits,
         labels=labels,
         train_op_fn=_no_op_train_fn)
-    with self.test_session():
+    with self.cached_session():
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index e44a69b..0f20ace 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -2056,7 +2056,7 @@
     var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
       `tf.estimator.VocabInfo`. The variable names should be "full" variables,
       not the names of the partitions.  If not explicitly provided, the variable
-      is assumed to have no vocabulary.
+      is assumed to have no (changes to) vocabulary.
     var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
       name of the previously-trained variable in `ckpt_to_initialize_from`. If
       not explicitly provided, the name of the variable is assumed to be same
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 4e7b00b..6329084 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -42,7 +42,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -28)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
       features, target = input_fn()
@@ -68,7 +68,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -30)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=128, shuffle=False, num_epochs=2)
       features, target = input_fn()
@@ -93,7 +93,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -28)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=0)
       features, target = input_fn()
@@ -114,7 +114,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -27)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
       features, target = input_fn()
@@ -150,7 +150,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -29)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=batch_size, shuffle=False, num_epochs=3)
       features, target = input_fn()
@@ -196,7 +196,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -28)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=batch_size, shuffle=False, num_epochs=1)
       features, target = input_fn()
@@ -221,7 +221,7 @@
     x = {'a': a, 'b': b}
     y = np.arange(-32, -30)
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
       features, target = input_fn()
@@ -240,7 +240,7 @@
   def testNumpyInputFnWithXAsNonDict(self):
     x = list(range(32, 36))
     y = np.arange(4)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):
         failing_input_fn = numpy_io.numpy_input_fn(
             x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -249,7 +249,7 @@
   def testNumpyInputFnWithXIsEmptyDict(self):
     x = {}
     y = np.arange(4)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
         failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
         failing_input_fn()
@@ -257,7 +257,7 @@
   def testNumpyInputFnWithXIsEmptyArray(self):
     x = np.array([[], []])
     y = np.arange(4)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):
         failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
         failing_input_fn()
@@ -268,7 +268,7 @@
     x = {'a': a, 'b': b}
     y = None
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
       features_tensor = input_fn()
@@ -291,7 +291,7 @@
   def testNumpyInputFnWithNonBoolShuffle(self):
     x = np.arange(32, 36)
     y = np.arange(4)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError,
                                    'shuffle must be provided and explicitly '
                                    'set as boolean'):
@@ -303,7 +303,7 @@
     x = {'__target_key__': array}
     y = np.arange(4)
 
-    with self.test_session():
+    with self.cached_session():
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
       input_fn()
@@ -318,7 +318,7 @@
     x_mismatch_length = {'a': np.arange(1), 'b': b}
     y_longer_length = np.arange(10)
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, 'Length of tensors in x and y is mismatched.'):
         failing_input_fn = numpy_io.numpy_input_fn(
@@ -341,7 +341,7 @@
     x = {'a': a, 'b': b}
     y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}
 
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = numpy_io.numpy_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
       features_tensor, targets_tensor = input_fn()
@@ -369,7 +369,7 @@
     b = np.arange(32, 36)
     x = {'a': a, 'b': b}
     y = {}
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):
         failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
         failing_input_fn()
@@ -379,7 +379,7 @@
     b = np.arange(32, 36)
     x = {'a': a, 'b': b}
     y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b}
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, '2 duplicate keys are found in both x and y'):
         failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index 6f13bc9..9e69fc7 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -102,7 +102,7 @@
   def testPandasInputFn_ProducesExpectedOutputs(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -116,7 +116,7 @@
   def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrameWithYAsDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -131,7 +131,7 @@
   def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrameWithYAsDataFrame()
       y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
       input_fn = pandas_io.pandas_input_fn(
@@ -147,7 +147,7 @@
   def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrameWithYAsDataFrame()
       y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
       input_fn = pandas_io.pandas_input_fn(
@@ -163,7 +163,7 @@
   def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       index = np.arange(100, 102)
       a = np.arange(2)
       b = np.arange(32, 34)
@@ -191,7 +191,7 @@
   def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       index = np.arange(100, 105)
       a = np.arange(5)
       b = np.arange(32, 37)
@@ -230,7 +230,7 @@
   def testPandasInputFn_OnlyX(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, _ = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -243,7 +243,7 @@
   def testPandasInputFn_ExcludesIndex(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -266,7 +266,7 @@
   def testPandasInputFn_RespectsEpoch_NoShuffle(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -276,7 +276,7 @@
   def testPandasInputFn_RespectsEpoch_WithShuffle(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -286,7 +286,7 @@
   def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
     if not HAS_PANDAS:
       return
-    with self.test_session() as session:
+    with self.cached_session() as session:
       x, y = self.makeTestDataFrame()
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -297,7 +297,7 @@
     if not HAS_PANDAS:
       return
     x, y = self.makeTestDataFrame()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       input_fn = pandas_io.pandas_input_fn(
           x, y, batch_size=3, shuffle=False, num_epochs=1)
 
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 290c460..3758243 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -26,20 +26,23 @@
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import keras as keras_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator import run_config as run_config_lib
-from tensorflow.python.estimator.inputs import numpy_io
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.optimizers import SGD
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.parsing_ops import gen_parsing_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.summary.writer import writer_cache
 from tensorflow.python.training import rmsprop
 from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
 
 
 try:
@@ -90,6 +93,58 @@
   return SimpleModel()
 
 
+def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
+  def input_fn():
+    ds = dataset_ops.Dataset.from_tensor_slices((x, y) if y is not None else x)
+    if shuffle:
+      ds = ds.shuffle(1000)
+    return ds.repeat(num_epochs).batch(batch_size)
+  return input_fn
+
+
+def get_multi_inputs_multi_outputs_data():
+  (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(16,),
+      num_classes=3,
+      random_seed=_RANDOM_SEED)
+  (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(16,),
+      num_classes=2,
+      random_seed=_RANDOM_SEED)
+  (m_train, _), (m_test, _) = testing_utils.get_test_data(
+      train_samples=_TRAIN_SIZE,
+      test_samples=50,
+      input_shape=(8,),
+      num_classes=2,
+      random_seed=_RANDOM_SEED)
+
+  c_train = keras.utils.to_categorical(c_train)
+  c_test = keras.utils.to_categorical(c_test)
+  d_train = keras.utils.to_categorical(d_train)
+  d_test = keras.utils.to_categorical(d_test)
+
+  train_data = {
+      'input_a': a_train,
+      'input_b': b_train,
+      'input_m': m_train,
+      'output_c': c_train,
+      'output_d': d_train
+  }
+  test_data = {
+      'input_a': a_test,
+      'input_b': b_test,
+      'input_m': m_test,
+      'output_c': c_test,
+      'output_d': d_test
+  }
+
+  return (train_data, test_data)
+
+
 def get_resource_for_simple_model(model_type='sequential',
                                   is_evaluate=False,):
   if model_type == 'sequential':
@@ -117,19 +172,19 @@
   y_train = keras.utils.to_categorical(y_train)
   y_test = keras.utils.to_categorical(y_test)
 
-  train_input_fn = numpy_io.numpy_input_fn(
+  train_input_fn = gen_input_fn(
       x=randomize_io_type(x_train, input_name),
       y=randomize_io_type(y_train, output_name),
       shuffle=False,
       num_epochs=None,
       batch_size=16)
 
-  evaluate_input_fn = numpy_io.numpy_input_fn(
+  evaluate_input_fn = gen_input_fn(
       x=randomize_io_type(x_test, input_name),
       y=randomize_io_type(y_test, output_name),
       num_epochs=1, shuffle=False)
 
-  predict_input_fn = numpy_io.numpy_input_fn(
+  predict_input_fn = gen_input_fn(
       x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
 
   inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
@@ -147,20 +202,21 @@
 
 
 def multi_inputs_multi_outputs_model():
-  a = keras.layers.Input(shape=(16,), name='input_a')
-  b = keras.layers.Input(shape=(16,), name='input_b')
-  m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+  input_a = keras.layers.Input(shape=(16,), name='input_a')
+  input_b = keras.layers.Input(shape=(16,), name='input_b')
+  input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
   dense = keras.layers.Dense(8, name='dense_1')
 
-  a_2 = dense(a)
+  interm_a = dense(input_a)
   # Read m
-  m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m)
-  s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2])
-  b_2 = dense(b)
-  merged = keras.layers.concatenate([s_2, b_2], name='merge')
-  c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
-  d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
-  model = keras.models.Model(inputs=[a, b, m], outputs=[c, d])
+  interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+  interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+  interm_b = dense(input_b)
+  merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+  output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+  output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+  model = keras.models.Model(
+      inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
   model.compile(
       loss='categorical_crossentropy',
       optimizer='rmsprop',
@@ -203,7 +259,7 @@
           optimizer='rmsprop',
           metrics=['mse', keras.metrics.categorical_accuracy])
 
-      with self.test_session():
+      with self.cached_session():
         est_keras = keras_lib.model_to_estimator(
             keras_model=keras_model, config=self._config)
         before_eval_results = est_keras.evaluate(
@@ -228,7 +284,7 @@
           metrics=['mse', keras.metrics.categorical_accuracy])
 
       my_hook = MyHook()
-      with self.test_session():
+      with self.cached_session():
         est_keras = keras_lib.model_to_estimator(
             keras_model=keras_model, config=self._config)
         before_eval_results = est_keras.evaluate(
@@ -252,7 +308,7 @@
         optimizer=rmsprop.RMSPropOptimizer(1e-3),
         metrics=['mse', keras.metrics.categorical_accuracy])
     my_hook = MyHook()
-    with self.test_session():
+    with self.cached_session():
       keras_model.fit(x_train, y_train, epochs=1)
 
       keras_est = keras_lib.model_to_estimator(
@@ -274,7 +330,7 @@
           optimizer=rmsprop.RMSPropOptimizer(1e-3),
           metrics=['mse', keras.metrics.categorical_accuracy])
 
-      with self.test_session():
+      with self.cached_session():
         est_keras = keras_lib.model_to_estimator(
             keras_model=keras_model,
             config=self._config)
@@ -297,7 +353,7 @@
         optimizer=rmsprop.RMSPropOptimizer(1e-3),
         metrics=['mse', keras.metrics.categorical_accuracy])
 
-    with self.test_session():
+    with self.cached_session():
       est_keras = keras_lib.model_to_estimator(
           keras_model=keras_model, config=self._config)
       est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
@@ -316,7 +372,7 @@
         optimizer=rmsprop.RMSPropOptimizer(1e-3),
         metrics=['mse', keras.metrics.categorical_accuracy])
 
-    with self.test_session():
+    with self.cached_session():
       # Create state
       keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
                                  np.random.random((10, _NUM_CLASS)))
@@ -343,7 +399,7 @@
         x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
             model_type='functional', is_evaluate=True)
 
-    with self.test_session():
+    with self.cached_session():
       metrics = [
           'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',
           'categorical_crossentropy', 'cosine_proximity', 'hinge',
@@ -357,7 +413,7 @@
       keras_model.fit(x_train, y_train, epochs=1)
       keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)
 
-    with self.test_session():
+    with self.cached_session():
       keras_est = keras_lib.model_to_estimator(
           keras_model=keras_model, config=self._config)
       est_eval = keras_est.evaluate(input_fn=eval_input_fn)
@@ -385,7 +441,7 @@
         x_test, _), _, pred_input_fn = get_resource_for_simple_model(
             model_type='sequential', is_evaluate=False)
 
-    with self.test_session():
+    with self.cached_session():
       keras_model.compile(
           loss='categorical_crossentropy',
           optimizer='adam',
@@ -393,7 +449,7 @@
       keras_model.fit(x_train, y_train, epochs=1)
       keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
 
-    with self.test_session():
+    with self.cached_session():
       keras_est = keras_lib.model_to_estimator(
           keras_model=keras_model, config=self._config)
       est_pred = [
@@ -402,51 +458,85 @@
       ]
     self.assertAllEqual(est_pred, keras_pred)
 
-  def test_multi_inputs_multi_outputs(self):
-    np.random.seed(_RANDOM_SEED)
-    (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
-        train_samples=_TRAIN_SIZE,
-        test_samples=50,
-        input_shape=(16,),
-        num_classes=3)
-    np.random.seed(_RANDOM_SEED)
-    (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
-        train_samples=_TRAIN_SIZE,
-        test_samples=50,
-        input_shape=(16,),
-        num_classes=2)
-    np.random.seed(_RANDOM_SEED)
-    (input_m_train, _), (input_m_test, _) = testing_utils.get_test_data(
-        train_samples=_TRAIN_SIZE,
-        test_samples=50,
-        input_shape=(8,),
-        num_classes=2)
-
-    c_train = keras.utils.to_categorical(c_train)
-    c_test = keras.utils.to_categorical(c_test)
-    d_train = keras.utils.to_categorical(d_train)
-    d_test = keras.utils.to_categorical(d_test)
+  def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+    train_data, test_data = get_multi_inputs_multi_outputs_data()
 
     def train_input_fn():
-      input_dict = {'input_a': a_train, 'input_b': b_train,
-                    'input_m': input_m_train.astype(np.str)}
-      output_dict = {'dense_2': c_train, 'dense_3': d_train}
+      input_dict = {
+          'input_a': train_data['input_a'],
+          'input_b': train_data['input_b'],
+          'input_m': train_data['input_m'].astype(np.str)
+      }
+      output_dict = {
+          'dense_2': train_data['output_c'],
+          'dense_3': train_data['output_d']
+      }
       return input_dict, output_dict
 
     def eval_input_fn():
-      input_dict = {'input_a': a_test, 'input_b': b_test,
-                    'input_m': input_m_test.astype(np.str)}
-      output_dict = {'dense_2': c_test, 'dense_3': d_test}
+      input_dict = {
+          'input_a': test_data['input_a'],
+          'input_b': test_data['input_b'],
+          'input_m': test_data['input_m'].astype(np.str)
+      }
+      output_dict = {
+          'dense_2': test_data['output_c'],
+          'dense_3': test_data['output_d']
+      }
       return input_dict, output_dict
 
-    with self.test_session():
+    def pred_input_fn():
+      input_dict = {
+          'input_a': test_data['input_a'],
+          'input_b': test_data['input_b'],
+          'input_m': test_data['input_m'].astype(np.str)
+      }
+      return input_dict
+
+    self.do_test_multi_inputs_multi_outputs_with_input_fn(
+        train_input_fn, eval_input_fn, pred_input_fn)
+
+  def test_multi_inputs_multi_outputs_with_input_fn_as_list(self):
+    train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+    def train_input_fn():
+      input_list = [
+          train_data['input_a'], train_data['input_b'],
+          train_data['input_m'].astype(np.str)
+      ]
+      output_list = [train_data['output_c'], train_data['output_d']]
+      return input_list, output_list
+
+    def eval_input_fn():
+      input_list = [
+          test_data['input_a'], test_data['input_b'],
+          test_data['input_m'].astype(np.str)
+      ]
+      output_list = [test_data['output_c'], test_data['output_d']]
+      return input_list, output_list
+
+    def pred_input_fn():
+      input_list = [
+          test_data['input_a'], test_data['input_b'],
+          test_data['input_m'].astype(np.str)
+      ]
+      return input_list
+
+    self.do_test_multi_inputs_multi_outputs_with_input_fn(
+        train_input_fn, eval_input_fn, pred_input_fn)
+
+  def do_test_multi_inputs_multi_outputs_with_input_fn(
+      self, train_input_fn, eval_input_fn, pred_input_fn):
+    with self.cached_session():
       model = multi_inputs_multi_outputs_model()
       est_keras = keras_lib.model_to_estimator(
           keras_model=model, config=self._config)
-      before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+      baseline_eval_results = est_keras.evaluate(
+          input_fn=eval_input_fn, steps=1)
       est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
-      after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
-      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+      est_keras.predict(input_fn=pred_input_fn)
 
   def test_init_from_file(self):
     if h5py is None:
@@ -456,7 +546,7 @@
         x_test, _), _, pred_input_fn = get_resource_for_simple_model(
             model_type='functional', is_evaluate=False)
 
-    with self.test_session():
+    with self.cached_session():
       keras_model.compile(
           loss='categorical_crossentropy',
           optimizer='rmsprop',
@@ -466,7 +556,7 @@
       fname = os.path.join(self._base_dir, 'keras_model.h5')
       keras.models.save_model(keras_model, fname)
 
-    with self.test_session():
+    with self.cached_session():
       keras_est = keras_lib.model_to_estimator(
           keras_model_path=fname, config=self._config)
       est_pred = [
@@ -479,19 +569,19 @@
     with self.assertRaisesRegexp(ValueError, 'Either'):
       keras_lib.model_to_estimator()
 
-    with self.test_session():
+    with self.cached_session():
       keras_model = simple_sequential_model()
       with self.assertRaisesRegexp(ValueError, 'not both'):
         keras_lib.model_to_estimator(
             keras_model=keras_model,
             keras_model_path=tempfile.mkdtemp(dir=self._base_dir))
 
-    with self.test_session():
+    with self.cached_session():
       keras_model = simple_sequential_model()
       with self.assertRaisesRegexp(ValueError, 'compiled'):
         keras_lib.model_to_estimator(keras_model=keras_model)
 
-    with self.test_session():
+    with self.cached_session():
       keras_model = simple_sequential_model()
       with self.assertRaisesRegexp(ValueError, 'not a local path'):
         keras_lib.model_to_estimator(
@@ -516,10 +606,10 @@
     model = simple_functional_model()
     model.compile(
         loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
-    with self.test_session():
+    with self.cached_session():
       est_keras = keras_lib.model_to_estimator(
           keras_model=model, config=self._config)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(KeyError,
                                    'Difference: .*invalid_input_name'):
         est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
@@ -547,20 +637,20 @@
     y_train = keras.utils.to_categorical(y_train, 2)
     input_name = keras_model.input_names[0]
     output_name = keras_model.output_names[0]
-    train_input_fn = numpy_io.numpy_input_fn(
+    train_input_fn = gen_input_fn(
         x=randomize_io_type(x_train, input_name),
         y=randomize_io_type(y_train, output_name),
         shuffle=False,
         num_epochs=None,
         batch_size=16)
     with self.assertRaisesRegexp(ValueError, 'relu6'):
-      with self.test_session():
+      with self.cached_session():
         est = keras_lib.model_to_estimator(
             keras_model=keras_model,
             model_dir=tempfile.mkdtemp(dir=self._base_dir))
         est.train(input_fn=train_input_fn, steps=1)
 
-    with self.test_session():
+    with self.cached_session():
       est = keras_lib.model_to_estimator(
           keras_model=keras_model,
           model_dir=tempfile.mkdtemp(dir=self._base_dir),
@@ -586,7 +676,7 @@
         }
     })
     with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
-      with self.test_session():
+      with self.cached_session():
         keras_lib.model_to_estimator(
             keras_model=keras_model,
             model_dir=tempfile.mkdtemp(dir=self._base_dir))
@@ -602,7 +692,7 @@
       gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
       sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
       self._config._session_config = sess_config
-      with self.test_session():
+      with self.cached_session():
         keras_lib.model_to_estimator(
             keras_model=keras_model, config=self._config)
         self.assertEqual(
@@ -618,7 +708,7 @@
         optimizer='rmsprop',
         metrics=['mse', keras.metrics.categorical_accuracy])
 
-    with self.test_session():
+    with self.cached_session():
       est_keras = keras_lib.model_to_estimator(
           keras_model=keras_model, model_dir=self._base_dir,
           config=run_config_lib.RunConfig())
@@ -629,7 +719,7 @@
       self.assertEqual(self._base_dir, est_keras._config.model_dir)
       self.assertEqual(self._base_dir, est_keras._model_dir)
 
-    with self.test_session():
+    with self.cached_session():
       est_keras = keras_lib.model_to_estimator(
           keras_model=keras_model, model_dir=self._base_dir,
           config=None)
@@ -648,7 +738,7 @@
         optimizer='rmsprop',
         metrics=['mse', keras.metrics.categorical_accuracy])
 
-    with self.test_session():
+    with self.cached_session():
       with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
         est_keras = keras_lib.model_to_estimator(
             keras_model=keras_model,
@@ -663,7 +753,7 @@
         optimizer='rmsprop',
         metrics=['mse', keras.metrics.categorical_accuracy])
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
                                    'constructor and `RunConfig`'):
         keras_lib.model_to_estimator(
@@ -676,7 +766,7 @@
         loss='categorical_crossentropy',
         optimizer=rmsprop.RMSPropOptimizer(1e-3),
         metrics=['mse', keras.metrics.categorical_accuracy])
-    with self.test_session():
+    with self.cached_session():
       keras_model.train_on_batch(
           np.random.random((10,) + _INPUT_SIZE),
           np.random.random((10, _NUM_CLASS)))
@@ -690,6 +780,32 @@
       keras_lib.model_to_estimator(
           keras_model=keras_model, config=self._config)
 
+  def assert_increasing_global_step(self, optimizer):
+    keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model(
+        model_type='sequential', is_evaluate=True)
+    keras_model.compile(
+        loss='categorical_crossentropy',
+        optimizer=optimizer,
+        metrics=['mse', keras.metrics.categorical_accuracy])
+    with self.cached_session() as sess:
+      keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
+      global_step = training_util.create_global_step()
+      features, labels = train_input_fn().make_one_shot_iterator().get_next()
+      spec = keras_model_fn(features, labels, mode=model_fn_lib.ModeKeys.TRAIN)
+
+      sess.run(variables.global_variables_initializer())
+      sess.run(variables.local_variables_initializer())
+
+      self.assertEqual(global_step.eval(), 0)  # Sanity check
+      sess.run(spec.train_op)
+      self.assertEqual(global_step.eval(), 1)
+
+  def test_model_fn_increments_global_step_tf_optimizer(self):
+    self.assert_increasing_global_step(rmsprop.RMSPropOptimizer(1e-3))
+
+  def test_model_fn_increments_global_step_keras_optimizer(self):
+    self.assert_increasing_global_step('rmsprop')
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index b1ca207..3773810 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -521,7 +521,12 @@
         eval_distribute=eval_distribute,
         experimental_distribute=experimental_distribute)
 
-    if train_distribute or eval_distribute or experimental_distribute:
+    # TODO(frankchn,priyag): Eventually use distributed coordinator for TPUs.
+    if ((train_distribute and
+         train_distribute.__class__.__name__ != 'TPUStrategy') or
+        (eval_distribute and
+         eval_distribute.__class__.__name__ != 'TPUStrategy') or
+        experimental_distribute):
       logging.info('Initializing RunConfig with distribution strategies.')
       distribute_coordinator_training.init_run_config(self, tf_config)
     else:
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 1017d4b..ac53a84 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -12,6 +12,7 @@
     srcs_version = "PY2AND3",
     deps = [
         ":feature_column",
+        ":feature_column_v2",
         "//tensorflow/python:util",
     ],
 )
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 2246d2f..9984379 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -169,7 +169,8 @@
                           weight_collections=None,
                           trainable=True,
                           cols_to_vars=None,
-                          scope=None):
+                          scope=None,
+                          cols_to_output_tensors=None):
   """See input_layer. `scope` is a name or variable scope to use."""
 
   feature_columns = _normalize_feature_columns(feature_columns)
@@ -202,14 +203,17 @@
             trainable=trainable)
         num_elements = column._variable_shape.num_elements()  # pylint: disable=protected-access
         batch_size = array_ops.shape(tensor)[0]
-        output_tensors.append(
-            array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+        output_tensor = array_ops.reshape(
+            tensor, shape=(batch_size, num_elements))
+        output_tensors.append(output_tensor)
         if cols_to_vars is not None:
           # Retrieve any variables created (some _DenseColumn's don't create
           # variables, in which case an empty list is returned).
           cols_to_vars[column] = ops.get_collection(
               ops.GraphKeys.GLOBAL_VARIABLES,
               scope=variable_scope.get_variable_scope().name)
+        if cols_to_output_tensors is not None:
+          cols_to_output_tensors[column] = output_tensor
     _verify_static_batch_size_equality(output_tensors, ordered_columns)
     return array_ops.concat(output_tensors, 1)
 
@@ -219,7 +223,8 @@
                 feature_columns,
                 weight_collections=None,
                 trainable=True,
-                cols_to_vars=None):
+                cols_to_vars=None,
+                cols_to_output_tensors=None):
   """Returns a dense `Tensor` as input layer based on given `feature_columns`.
 
   Generally a single example in training data is described with FeatureColumns.
@@ -264,6 +269,9 @@
         dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
                         <tf.Variable 'some_variable:1' shape=(5, 10)]}
       If a column creates no variables, its value will be an empty list.
+    cols_to_output_tensors: If not `None`, must be a dictionary that will be
+      filled with a mapping from `_FeatureColumn` to the associated
+      output `Tensor`s.
 
   Returns:
     A `Tensor` which represents input layer of a model. Its shape
@@ -273,8 +281,13 @@
   Raises:
     ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
   """
-  return _internal_input_layer(features, feature_columns, weight_collections,
-                               trainable, cols_to_vars)
+  return _internal_input_layer(
+      features,
+      feature_columns,
+      weight_collections=weight_collections,
+      trainable=trainable,
+      cols_to_vars=cols_to_vars,
+      cols_to_output_tensors=cols_to_output_tensors)
 
 
 # TODO(akshayka): InputLayer should be a subclass of Layer, and it
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 9b48223..abb79ef 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1637,6 +1637,40 @@
         self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
         self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
 
+  def test_fills_cols_to_output_tensors(self):
+    # Provide three _DenseColumns to input_layer: a _NumericColumn, a
+    # _BucketizedColumn, and an _EmbeddingColumn.  Only the _EmbeddingColumn
+    # creates a Variable.
+    apple_numeric_column = fc.numeric_column('apple_numeric_column')
+    banana_dense_feature = fc.numeric_column('banana_dense_feature')
+    banana_dense_feature_bucketized = fc.bucketized_column(
+        banana_dense_feature, boundaries=[0.])
+    cherry_sparse_column = fc.categorical_column_with_hash_bucket(
+        'cherry_sparse_feature', hash_bucket_size=5)
+    dragonfruit_embedding_column = fc.embedding_column(
+        cherry_sparse_column, dimension=10)
+    with ops.Graph().as_default():
+      features = {
+          'apple_numeric_column': [[3.], [4.]],
+          'banana_dense_feature': [[-1.], [4.]],
+          'cherry_sparse_feature': [['a'], ['x']],
+      }
+      cols_to_output_tensors = {}
+      all_cols = [
+          apple_numeric_column, banana_dense_feature_bucketized,
+          dragonfruit_embedding_column
+      ]
+      input_layer = fc.input_layer(
+          features, all_cols, cols_to_output_tensors=cols_to_output_tensors)
+
+      # We check the mapping by checking that we have the right keys,
+      # and that the values (output_tensors) were indeed the ones used to
+      # form the input layer.
+      self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
+      input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
+      output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
+      self.assertItemsEqual(input_layer_inputs, output_tensors)
+
   def test_dense_collection(self):
     price = fc.numeric_column('price')
     with ops.Graph().as_default() as g:
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index aa66ed7..28c5c82 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -385,6 +385,10 @@
             'You can wrap a categorical column with an '
             'embedding_column or indicator_column. Given: {}'.format(column))
 
+  @property
+  def _is_feature_layer(self):
+    return True
+
   def build(self, _):
     for column in sorted(self._feature_columns, key=lambda x: x.name):
       if isinstance(column, SharedEmbeddingColumn):
@@ -409,7 +413,13 @@
       A `Tensor` which represents input layer of a model. Its shape
       is (batch_size, first_layer_dimension) and its dtype is `float32`.
       first_layer_dimension is determined based on given `feature_columns`.
+
+    Raises:
+      ValueError: If features are not a dictionary.
     """
+    if not isinstance(features, dict):
+      raise ValueError('We expected a dictionary here. Instead we got: ',
+                       features)
     transformation_cache = FeatureTransformationCache(features)
     output_tensors = []
     ordered_columns = []
@@ -431,6 +441,12 @@
     _verify_static_batch_size_equality(output_tensors, ordered_columns)
     return array_ops.concat(output_tensors, 1)
 
+  def compute_output_shape(self, input_shape):
+    total_elements = 0
+    for column in sorted(self._feature_columns, key=lambda x: x.name):
+      total_elements += column.variable_shape.num_elements()
+    return (input_shape[0], total_elements)
+
 
 def linear_model(features,
                  feature_columns,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 6b343ec..58168e0 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -2786,6 +2786,21 @@
       with _initialized_session():
         self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
 
+  def test_compute_output_shape(self):
+    price1 = fc.numeric_column('price1', shape=2)
+    price2 = fc.numeric_column('price2', shape=4)
+    with ops.Graph().as_default():
+      features = {
+          'price1': [[1., 2.], [5., 6.]],
+          'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+      }
+      feature_layer = FeatureLayer([price1, price2])
+      self.assertEqual((None, 6), feature_layer.compute_output_shape((None,)))
+      net = feature_layer(features)
+      with _initialized_session():
+        self.assertAllClose(
+            [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval())
+
   def test_raises_if_shape_mismatch(self):
     price = fc.numeric_column('price', shape=2)
     with ops.Graph().as_default():
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index eca34ac..4b2706d 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -105,7 +105,8 @@
     scalar_cache = ctx.scalar_cache()
     tensor = scalar_cache.get(cache_key, None)
     if tensor is not None:
-      return tensor
+      return ops.EagerTensor(
+          value, context=handle, device=device, dtype=dtype, other_value=tensor)
     t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
     scalar_cache[cache_key] = t
     return t
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
index a69018d..bc3c81b 100644
--- a/tensorflow/python/framework/error_interpolation.py
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -15,7 +15,7 @@
 """Function for interpolating formatted errors from the TensorFlow runtime.
 
 Exposes the function `interpolate` to interpolate messages with tags of the form
-^^type:name:format^^.
+{{type name}}.
 """
 
 from __future__ import absolute_import
@@ -32,9 +32,9 @@
 from tensorflow.python.util import tf_stack
 
 _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
-_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX)
+_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX)
 _INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
-_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
+_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX, re.DOTALL)
 
 _ParseTag = collections.namedtuple("_ParseTag", ["type", "name"])
 
@@ -48,8 +48,8 @@
   """Parses the message.
 
   Splits the message into separators and tags. Tags are named tuples
-  representing the string ^^type:name^^ and they are separated by
-  separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are
+  representing the string {{type name}} and they are separated by
+  separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are
   two tags and three separators. The separators are the numeric characters.
 
   Args:
@@ -58,7 +58,7 @@
   Returns:
     (list of separator strings, list of _ParseTags).
 
-    For example, if message is "123^^node:Foo^^456" then this function
+    For example, if message is "123{{node Foo}}456" then this function
     returns (["123", "456"], [_ParseTag("node", "Foo")])
   """
   seps = []
@@ -276,7 +276,7 @@
         message.
 
   Returns:
-    The string with tags of the form ^^type:name^^ interpolated.
+    The string with tags of the form {{type name}} interpolated.
   """
   seps, tags = _parse_message(error_message)
   subs = []
@@ -288,7 +288,7 @@
     except KeyError:
       op = None
 
-    msg = "^^%s:%s^^" % (t.type, t.name)
+    msg = "{{%s %s}}" % (t.type, t.name)
     if op is not None:
       field_dict = compute_field_dict(op)
       if t.type == "node":
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
index a7c7bbf..1b77548 100644
--- a/tensorflow/python/framework/error_interpolation_test.py
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -167,26 +167,31 @@
     self.assertEqual(interpolated_string, normal_string)
 
   def testOneTagWithAFakeNameResultsInPlaceholders(self):
-    one_tag_string = "^^node:MinusOne^^"
+    one_tag_string = "{{node MinusOne}}"
     interpolated_string = error_interpolation.interpolate(
         one_tag_string, self.graph)
     self.assertEqual(one_tag_string, interpolated_string)
 
   def testTwoTagsNoSeps(self):
-    two_tags_no_seps = "^^node:One^^^^node:Three^^"
+    two_tags_no_seps = "{{node One}}{{node Three}}"
     interpolated_string = error_interpolation.interpolate(
         two_tags_no_seps, self.graph)
     self.assertRegexpMatches(interpolated_string,
                              "constant_op.py:[0-9]+.*constant_op.py:[0-9]+")
 
   def testTwoTagsWithSeps(self):
-    two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;"
+    two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
     interpolated_string = error_interpolation.interpolate(
         two_tags_with_seps, self.graph)
     expected_regex = (
-        r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]*\) ;;;$")
+        r"^;;;.*constant_op.py:[0-9]+\) ,,,.*constant_op.py:[0-9]+\) ;;;$")
     self.assertRegexpMatches(interpolated_string, expected_regex)
 
+  def testNewLine(self):
+    newline = "\n\n{{node One}}"
+    interpolated_string = error_interpolation.interpolate(newline, self.graph)
+    self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*")
+
 
 class InterpolateDeviceSummaryTest(test.TestCase):
 
@@ -206,23 +211,23 @@
     self.graph = self.three.graph
 
   def testNodeZeroHasNoDeviceSummaryInfo(self):
-    message = "^^colocation_node:zero^^"
+    message = "{{colocation_node zero}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertIn("No device assignments were active", result)
 
   def testNodeOneHasExactlyOneInterpolatedDevice(self):
-    message = "^^colocation_node:one^^"
+    message = "{{colocation_node one}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertEqual(2, result.count("tf.device(/cpu)"))
 
   def testNodeTwoHasTwoInterpolatedDevice(self):
-    message = "^^colocation_node:two^^"
+    message = "{{colocation_node two}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertEqual(2, result.count("tf.device(/cpu)"))
     self.assertEqual(2, result.count("tf.device(/cpu:0)"))
 
   def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
-    message = "^^colocation_node:three^^"
+    message = "{{colocation_node three}}"
     result = error_interpolation.interpolate(message, self.graph)
     num_devices = result.count("tf.device")
     self.assertEqual(2, num_devices)
@@ -256,12 +261,12 @@
     self.graph = node_three.graph
 
   def testNodeThreeHasColocationInterpolation(self):
-    message = "^^colocation_node:Three_with_one^^"
+    message = "{{colocation_node Three_with_one}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertIn("colocate_with(One)", result)
 
   def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
-    message = "^^colocation_node:Four_with_three^^"
+    message = "{{colocation_node Four_with_three}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertIn("colocate_with(Three_with_one)", result)
     self.assertNotIn(
@@ -269,13 +274,13 @@
         "Node One should not appear in Four_with_three's summary:\n%s" % result)
 
   def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
-    message = "^^colocation_node:Five_with_one_with_two^^"
+    message = "{{colocation_node Five_with_one_with_two}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertIn("colocate_with(One)", result)
     self.assertIn("colocate_with(Two)", result)
 
   def testColocationInterpolationForNodeLackingColocation(self):
-    message = "^^colocation_node:One^^"
+    message = "{{colocation_node One}}"
     result = error_interpolation.interpolate(message, self.graph)
     self.assertIn("No node-device colocations", result)
     self.assertNotIn("Two", result)
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 5eb5914..6901715 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -37,7 +37,7 @@
     load_library.load_file_system_library(file_system_library)
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.WholeFileReader("test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       queue.enqueue_many([["test://foo"]]).run()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ee723ba..903768a 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -419,7 +419,7 @@
       with ops.control_dependencies([z]):
         return x * 2
 
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       z = Foo(constant_op.constant(3.0))
       self.assertAllEqual(z.eval(), 6.0)
 
@@ -434,7 +434,7 @@
     # Foo contains a stateful op (Assert).
     self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
     g = ops.Graph()
-    with g.as_default(), self.test_session():
+    with g.as_default(), self.cached_session():
       self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "assertion failed.*-3"):
@@ -448,7 +448,7 @@
           [control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
         return array_ops.identity(x)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(1.0, MyFn(1.0).eval())
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "assertion"):
@@ -667,7 +667,7 @@
 
     with ops.Graph().as_default():
       z = CubeXPlusY(3.0, -2.0)
-      with self.test_session():
+      with self.cached_session():
         self.assertAllEqual(z.eval(), 25.0)
 
   def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@
 
     with ops.Graph().as_default():
       z = CubeXPlusY(3.0, -2.0)
-      with self.test_session():
+      with self.cached_session():
         self.assertAllEqual(z.eval(), 25.0)
 
   def testUnusedFunction(self):
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 18e7d8a..2b4d8e7 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -396,7 +396,7 @@
 
       # Run the imported graph.
       # TODO(b/76173421): make this work (currently DCHECKS)
-      # with self.test_session() as sess:
+      # with self.cached_session() as sess:
       #   sess.run(imported_init)
       #   self.assertEqual(sess.run(imported_var), 1.0)
       #   self.assertEqual(sess.run(imported_assign), 2.0)
@@ -417,7 +417,7 @@
       imported_r, = importer.import_graph_def(graph_def,
                                               return_elements=[r.name])
       self.assertEqual(imported_r.name, "import/" + r.name)
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         self.assertEqual(sess.run(imported_r), 10)
 
   def testImportWhileLoopInCond(self):
@@ -436,7 +436,7 @@
       pred = array_ops.placeholder(dtypes.bool)
       out = control_flow_ops.cond(pred, ImportFn,
                                   lambda: constant_op.constant(1))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         self.assertEqual(sess.run(out, {pred: True}), 10)
         self.assertEqual(sess.run(out, {pred: False}), 1)
 
@@ -457,7 +457,7 @@
       out = control_flow_ops.while_loop(
           lambda i: i < 2, ImportFn, [0],
           shape_invariants=[tensor_shape.TensorShape(None)])
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         self.assertEqual(sess.run(out), 10)
 
   def testTypeMismatchInGraphDef(self):
@@ -929,7 +929,7 @@
           input_map={"a:0": constant_op.constant(5.0)},
           name="",
           return_elements=["id:0"])
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(5.0, t.eval())
 
   def testInvalidInputForReturnOperations(self):
@@ -958,7 +958,7 @@
       array_ops.stack([c, c], name="pack")
     gdef = g.as_graph_def()
 
-    with self.test_session():
+    with self.cached_session():
       pack, = importer.import_graph_def(gdef, return_elements=["pack"])
       self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
 
@@ -1063,7 +1063,7 @@
       self.assertEqual([10], biases_grad.get_shape())
 
   def testLargeGraph(self):
-    with self.test_session():
+    with self.cached_session():
       # The default message byte limit is 64M. Ours is 2G with a warning at 512.
       # Adding a 130M entries float32 tensor should exceed the warning, but not
       # the hard limit.
@@ -1254,7 +1254,7 @@
 
     z = TestFunc()
 
-    with self.test_session():
+    with self.cached_session():
       z_val = z.eval()
       self.assertEqual(z_val, -2.0)
 
@@ -1284,7 +1284,7 @@
       z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
                                      input_map=input_map)[0]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       z1_val, z2_val = sess.run((z1, z2))
       self.assertAllEqual(z1_val, z2_val)
 
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 6e5f7aa..fc98b91 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -117,7 +117,7 @@
       self.assertEqual(new_output_value, output_value)
 
   def testStrippedOpListNestedFunctions(self):
-    with self.test_session():
+    with self.cached_session():
       # Square two levels deep
       @function.Defun(dtypes.int32)
       def f0(x):
@@ -169,7 +169,7 @@
     # and "Tout" maps to complex64. Since these attr values map to their
     # defaults, they must be stripped unless stripping of default attrs is
     # disabled.
-    with self.test_session():
+    with self.cached_session():
       real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
       imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
       math_ops.complex(real_num, imag_num, name="complex")
@@ -212,7 +212,8 @@
 
   def testDefaultAttrStrippingNestedFunctions(self):
     """Verifies that default attributes are stripped from function node defs."""
-    with self.test_session():
+    with self.cached_session():
+
       @function.Defun(dtypes.float32, dtypes.float32)
       def f0(i, j):
         return math_ops.complex(i, j, name="double_nested_complex")
@@ -251,7 +252,7 @@
     meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
     meta_info_def.stripped_op_list.op.add()
 
-    with self.test_session():
+    with self.cached_session():
       meta_graph_def = meta_graph.create_meta_graph_def(
           meta_info_def=meta_info_def, graph_def=graph_def,
           strip_default_attrs=True)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 4cfd639..343f52f 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -55,8 +55,10 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_stack
 from tensorflow.python.util.deprecation import deprecated_args
@@ -5363,6 +5365,7 @@
   computational graph).
 
   For example:
+
   ```python
   tf.enable_eager_execution()
 
@@ -5807,11 +5810,8 @@
   _STREAMING_MODEL_PORTS = "streaming_model_ports"
 
   @decorator_utils.classproperty
+  @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.")
   def VARIABLES(cls):  # pylint: disable=no-self-argument
-    logging.log_first_n(logging.WARN,
-                        "VARIABLES collection name is deprecated, please use "
-                        "GLOBAL_VARIABLES instead; VARIABLES will be removed "
-                        "after 2017-03-02.", 1)
     return cls.GLOBAL_VARIABLES
 
 
@@ -5825,23 +5825,11 @@
     graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
       after this function runs.
   """
-  # pylint: disable=protected-access
-  # OrderedDict, constructed on Graph creation, makes a simple reference loop
-  # and hides it in an __attribute in some Python versions. We don't need to
-  # throw an error if we can't find it, but if we do find it we can break the
-  # loop to avoid creating work for the garbage collector.
-  graph_operations = graph.get_operations()
-  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
-  # pylint: enable=protected-access
-  if problematic_cycle:
-    try:
-      del problematic_cycle[0][:]
-    except TypeError:
-      # This is probably not one of the problematic Python versions. Continue
-      # with the rest of our cleanup.
-      pass
+  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
+
   # Now clean up Operation<->Graph reference cycles by clearing all of the
   # attributes for the Graph and its ops.
+  graph_operations = graph.get_operations()
   for op in graph_operations:
     op.__dict__ = {}
   graph.__dict__ = {}
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index ced0581..d59adf3 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@
 class ResourceTest(test_util.TensorFlowTestCase):
 
   def testBuildGraph(self):
-    with self.test_session():
+    with self.cached_session():
       pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
       test_ops.resource_create_op(pt).run()
 
   def testInitialize(self):
-    with self.test_session():
+    with self.cached_session():
       handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
       resources.register_resource(
           handle=handle,
@@ -100,35 +100,35 @@
         pass
 
   def testAddShape(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.zeros([2, 3])
       b = array_ops.ones([1, 3])
       c = a + b
       self.assertEqual([2, 3], c.shape)
 
   def testUnknownDim(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
       b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
       c = a + b
       self.assertEqual([2, None, 3], c.shape.as_list())
 
   def testUnknownShape(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
       b = array_ops.ones([1, 3])
       c = a + b
       self.assertEqual(tensor_shape.unknown_shape(), c.shape)
 
   def testScalarShape(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
       b = array_ops.ones([])
       c = a + b
       self.assertEqual(tensor_shape.scalar(), c.shape)
 
   def testShapeFunctionError(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.ones([1, 2, 3])
       b = array_ops.ones([4, 5, 6])
       with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@
 class IndexedSlicesTest(test_util.TensorFlowTestCase):
 
   def testToTensor(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
       indices = constant_op.constant([0, 2])
       dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@
       self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
 
   def testNegation(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
       indices = constant_op.constant([0, 2])
       x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@
       self.assertAllEqual(x.indices.eval(), [0, 2])
 
   def testScalarMul(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
       indices = constant_op.constant([0, 2])
       x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@
     self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
 
   def testConvertToTensorNestedArray(self):
-    with self.test_session():
+    with self.cached_session():
       values = [[2], [3], [5], [7]]
       tensor = ops.convert_to_tensor(values)
       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
       self.assertAllEqual(values, tensor.eval())
 
   def testShapeTuple(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(1)
       self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access
 
@@ -328,14 +328,14 @@
       self.assertTrue(isinstance(converted, ops.EagerTensor))
 
   def testConvertToTensorNestedTuple(self):
-    with self.test_session():
+    with self.cached_session():
       values = ((2,), (3,), (5,), (7,))
       tensor = ops.convert_to_tensor(values)
       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
       self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
 
   def testConvertToTensorNestedTensors(self):
-    with self.test_session():
+    with self.cached_session():
       values = ((2,), (3,), (5,), (7,))
       tensor = ops.convert_to_tensor(
           [constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@
       self.assertAllEqual(values, tensor.eval())
 
   def testConvertToTensorNestedMix(self):
-    with self.test_session():
+    with self.cached_session():
       values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
       tensor = ops.convert_to_tensor(values)
       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
       self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
 
   def testConvertToTensorPreferred(self):
-    with self.test_session():
+    with self.cached_session():
       values = [2, 3, 5, 7]
       tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
       self.assertEqual(dtypes.float32, tensor.dtype)
 
-    with self.test_session():
+    with self.cached_session():
       # Convert empty tensor to anything.
       values = []
       tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
       self.assertEqual(dtypes.int64, tensor.dtype)
 
-    with self.test_session():
+    with self.cached_session():
       # The preferred dtype is a type error and will convert to
       # float32 instead.
       values = [1.23]
@@ -941,7 +941,7 @@
     self.assertEqual("bar_2", g.unique_name("bar"))
 
   def testNameAndVariableScope(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with sess.graph.name_scope("l0"):
         with variable_scope.variable_scope("l1"):
           with sess.graph.name_scope("l1") as scope:
@@ -2164,7 +2164,7 @@
 
     g = ops.Graph()
     with g.as_default():
-      with self.test_session():
+      with self.cached_session():
         # First ensure that graphs that are not building functions are
         # not escaped.
         function_with_variables("foo")
@@ -2416,11 +2416,11 @@
     return (a, b)
 
   def testNoLabel(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual((None, None), self._get_test_attrs())
 
   def testLabelMap(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a1 = self._get_test_attrs()
       with sess.graph._attr_scope({
           "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2454,12 +2454,12 @@
 class KernelLabelTest(test_util.TensorFlowTestCase):
 
   def testNoLabel(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(b"My label is: default",
                           test_ops.kernel_label().eval())
 
   def testLabelMap(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_1 = test_ops.kernel_label()
       # pylint: disable=protected-access
       with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2900,7 +2900,7 @@
 class TracebackTest(test_util.TensorFlowTestCase):
 
   def testTracebackWithStartLines(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant(2.0)
       sess.run(
           a,
diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc
index f227034..f6aef5b 100644
--- a/tensorflow/python/framework/python_op_gen_internal.cc
+++ b/tensorflow/python/framework/python_op_gen_internal.cc
@@ -15,18 +15,20 @@
 
 #include "tensorflow/python/framework/python_op_gen_internal.h"
 
+#include <float.h>
 #include <stdio.h>
+#include <iomanip>
 #include <sstream>
 #include <unordered_map>
 #include "tensorflow/core/framework/api_def.pb.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def.pb_text.h"
 #include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb_text.h"
 #include "tensorflow/core/framework/op_def_util.h"
 #include "tensorflow/core/framework/op_gen_lib.h"
-#include "tensorflow/core/framework/tensor.pb_text.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor.pb_text.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
@@ -435,7 +437,12 @@
     if (std::isnan(value.f()) || std::isinf(value.f())) {
       return strings::StrCat("float('", value.f(), "')");
     } else {
-      return strings::StrCat(value.f());
+      // Use locale-independent conversion.
+      static_assert(FLT_DIG < 10, "FLT_DIG is too big");
+      std::ostringstream s;
+      s.imbue(std::locale::classic());
+      s << std::setprecision(FLT_DIG) << value.f();
+      return s.str();
     }
   } else if (type == "bool") {
     return value.b() ? "True" : "False";
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index 2bcfbc1..22423c4 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -45,7 +45,7 @@
       self.assertEqual(sp.dense_shape.dtype, dtypes.int64)
       self.assertEqual(sp.get_shape(), (4, 5))
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         value = sp.eval()
         self.assertAllEqual(indices, value.indices)
         self.assertAllEqual(values, value.values)
@@ -81,14 +81,14 @@
 class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
 
   def test_convert_dense(self):
-    with self.test_session():
+    with self.cached_session():
       value = [42, 43]
       from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
           value)
       self.assertAllEqual(value, from_value.eval())
 
   def test_convert_sparse(self):
-    with self.test_session():
+    with self.cached_session():
       indices = [[0, 1], [1, 0]]
       values = [42, 43]
       shape = [2, 2]
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index d6de45f..1d594e4 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -65,7 +65,7 @@
     self.assertFalse(c0.op in d.op.control_inputs)
     self.assertTrue(c.op in d.op.control_inputs)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c_out = sess.run([c])
       n_out = sess.run([n])
       d_out = sess.run([d])
@@ -144,7 +144,7 @@
     b = subscribe.subscribe(b,
                             lambda t: script_ops.py_func(sub, [t], [t.dtype]))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c_out = sess.run([c])
       d_out = sess.run([d])
 
@@ -204,7 +204,7 @@
     self.assertIs(c_sub, c_sub3)
 
     # Expect the three side effect graphs to have been evaluated.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([c_sub])
     self.assertIn('graph1', shared)
     self.assertIn('graph2', shared)
@@ -227,7 +227,7 @@
         v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
     self.assertTrue(subscribe._is_subscribed_identity(v1_sub))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize the variables first.
       sess.run([v1.initializer])
       sess.run([v2.initializer])
@@ -272,7 +272,7 @@
     self.assertIs(tensor_array_sub, tensor_array.handle)
     self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([reader])
     self.assertEqual(0, len(shared))
 
@@ -303,7 +303,7 @@
     subscribe.subscribe(sparse_add.op.outputs,
                         lambda t: script_ops.py_func(sub, [t], [t.dtype]))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([neg])
 
     # All three ops have been processed.
@@ -374,7 +374,7 @@
     # Verify that sub(x1) and sub(branch) are not.
     self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(cond)
 
     self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 11b681d..3c2a736 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -606,8 +606,8 @@
       slice.
 
     Raises:
-      ValueError: If `key` is a slice, and any of its elements are negative, or
-        if `self` is completely unknown and the step is set.
+      ValueError: If `key` is a slice and `self` is completely unknown and
+        the step is set.
     """
     if self._dims is not None:
       if isinstance(key, slice):
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index b14290c..26170b0 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -367,7 +367,7 @@
     A `TensorProto`. Depending on the type, it may contain data in the
     "tensor_content" attribute, which is not directly useful to Python programs.
     To access the values you should convert the proto back to a numpy ndarray
-    with `tensor_util.MakeNdarray(proto)`.
+    with `tf.make_ndarray(proto)`.
 
     If `values` is a `TensorProto`, it is immediately returned; `dtype` and
     `shape` are ignored.
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 395cf43..bdf759f 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -768,7 +768,7 @@
       def __array__(self, dtype=None):
         return np.asarray(self.array, dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ma = MockArray(np.array([10, 20, 30]))
       t = ops.convert_to_tensor(ma)
       a = sess.run(t)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b5388ad..b739823 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -69,6 +69,7 @@
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
+from tensorflow.python.util import memory
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.protobuf import compare
@@ -413,15 +414,13 @@
     The wrapped function
   """
 
-  # pylint: disable=protected-access
   def wrapper(*args, **kwargs):
-    prev_value = control_flow_ops._ENABLE_COND_V2
-    control_flow_ops._ENABLE_COND_V2 = True
+    prev_value = control_flow_ops.ENABLE_COND_V2
+    control_flow_ops.ENABLE_COND_V2 = True
     try:
       fn(*args, **kwargs)
     finally:
-      control_flow_ops._ENABLE_COND_V2 = prev_value
-  # pylint: enable=protected-access
+      control_flow_ops.ENABLE_COND_V2 = prev_value
 
   return wrapper
 
@@ -438,7 +437,7 @@
   Returns:
     cls with new test methods added
   """
-  if control_flow_ops._ENABLE_COND_V2:
+  if control_flow_ops.ENABLE_COND_V2:
     return cls
 
   for name, value in cls.__dict__.copy().items():
@@ -465,29 +464,31 @@
       f(self, **kwargs)
       gc.collect()
       previous_count = len(gc.get_objects())
-      collection_sizes_before = {
-          collection: len(ops.get_collection(collection))
-          for collection in ops.get_default_graph().collections
-      }
+      if ops.has_default_graph():
+        collection_sizes_before = {
+            collection: len(ops.get_collection(collection))
+            for collection in ops.get_default_graph().collections
+        }
       for _ in range(3):
         f(self, **kwargs)
       # Note that gc.get_objects misses anything that isn't subject to garbage
       # collection (C types). Collections are a common source of leaks, so we
       # test for collection sizes explicitly.
-      for collection_key in ops.get_default_graph().collections:
-        collection = ops.get_collection(collection_key)
-        size_before = collection_sizes_before.get(collection_key, 0)
-        if len(collection) > size_before:
-          raise AssertionError(
-              ("Collection %s increased in size from "
-               "%d to %d (current items %s).") % (collection_key, size_before,
-                                                  len(collection), collection))
-        # Make sure our collection checks don't show up as leaked memory by
-        # removing references to temporary variables.
-        del collection
-        del collection_key
-        del size_before
-      del collection_sizes_before
+      if ops.has_default_graph():
+        for collection_key in ops.get_default_graph().collections:
+          collection = ops.get_collection(collection_key)
+          size_before = collection_sizes_before.get(collection_key, 0)
+          if len(collection) > size_before:
+            raise AssertionError(
+                ("Collection %s increased in size from "
+                 "%d to %d (current items %s).") %
+                (collection_key, size_before, len(collection), collection))
+          # Make sure our collection checks don't show up as leaked memory by
+          # removing references to temporary variables.
+          del collection
+          del collection_key
+          del size_before
+        del collection_sizes_before
       gc.collect()
       # There should be no new Python objects hanging around.
       new_count = len(gc.get_objects())
@@ -535,15 +536,16 @@
 
     tensors_before = set(
         id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
-    if context.executing_eagerly():
-      f(self, **kwargs)
-      ops.reset_default_graph()
-    else:
-      # Run the test in a new graph so that collections get cleared when it's
-      # done, but inherit the graph key so optimizers behave.
-      outside_graph_key = ops.get_default_graph()._graph_key
-      with ops.Graph().as_default():
-        ops.get_default_graph()._graph_key = outside_graph_key
+    outside_executed_eagerly = context.executing_eagerly()
+    # Run the test in a new graph so that collections get cleared when it's
+    # done, but inherit the graph key so optimizers behave.
+    outside_graph_key = ops.get_default_graph()._graph_key
+    with ops.Graph().as_default():
+      ops.get_default_graph()._graph_key = outside_graph_key
+      if outside_executed_eagerly:
+        with context.eager_mode():
+          f(self, **kwargs)
+      else:
         f(self, **kwargs)
     # Make an effort to clear caches, which would otherwise look like leaked
     # Tensors.
@@ -777,7 +779,7 @@
 
       def run_eagerly(self, **kwargs):
         if not use_gpu:
-          with ops.device("/cpu:0"):
+          with ops.device("/device:CPU:0"):
             f(self, **kwargs)
         else:
           f(self, **kwargs)
@@ -1072,13 +1074,9 @@
     if context.executing_eagerly():
       yield None
     else:
-      sess = self._create_session(graph, config, use_gpu, force_gpu)
-      with self._constrain_devices_and_set_default(
-          sess, use_gpu, force_gpu) as constrained_sess:
-        # We need to do this to make sure the session closes, otherwise, even
-        # if the user does with self.session():, it will not close the session.
-        with constrained_sess:
-          yield constrained_sess
+      with self._create_session(graph, config, force_gpu) as sess:
+        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
+          yield sess
 
   @contextlib.contextmanager
   def cached_session(self,
@@ -1126,10 +1124,11 @@
     if context.executing_eagerly():
       yield None
     else:
-      with self._get_cached_session(
-          graph, config, use_gpu, force_gpu,
-          crash_if_inconsistent_args=True) as sess:
-        yield sess
+      sess = self._get_cached_session(
+          graph, config, force_gpu, crash_if_inconsistent_args=True)
+      with self._constrain_devices_and_set_default(sess, use_gpu,
+                                                   force_gpu) as cached:
+        yield cached
 
   @contextlib.contextmanager
   def test_session(self,
@@ -1145,10 +1144,11 @@
       yield None
     else:
       if graph is None:
-        with self._get_cached_session(
-            graph, config, use_gpu, force_gpu,
-            crash_if_inconsistent_args=False) as sess:
-          yield sess
+        sess = self._get_cached_session(
+            graph, config, force_gpu, crash_if_inconsistent_args=False)
+        with self._constrain_devices_and_set_default(sess, use_gpu,
+                                                     force_gpu) as cached:
+          yield cached
       else:
         with self.session(graph, config, use_gpu, force_gpu) as sess:
           yield sess
@@ -1326,9 +1326,17 @@
   def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
     a = self._GetNdArray(a)
     b = self._GetNdArray(b)
-    self.assertEqual(
-        a.shape, b.shape,
-        "Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
+    # When the array rank is small, print its contents. Numpy array printing is
+    # implemented using inefficient recursion so prints can cause tests to
+    # time out.
+    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
+      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
+                            "%s.") % (a.shape, b.shape, b)
+    else:
+      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
+                                                                     b.shape)
+    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+
     if not np.allclose(a, b, rtol=rtol, atol=atol):
       # Prints more details than np.testing.assert_allclose.
       #
@@ -1654,7 +1662,7 @@
         if any of the elements do not fall in the specified range.
     """
     target = self._GetNdArray(target)
-    if not (np.issubdtype(target.dtype, np.float) or
+    if not (np.issubdtype(target.dtype, np.floating) or
             np.issubdtype(target.dtype, np.integer)):
       raise AssertionError(
           "The value of %s does not have an ordered numeric type, instead it "
@@ -1831,94 +1839,78 @@
         elif use_gpu:
           yield sess
         else:
-          with sess.graph.device("/cpu:0"):
+          with sess.graph.device("/device:CPU:0"):
             yield sess
 
-  def _create_session(self, graph, config, use_gpu, force_gpu):
+  def _create_session(self, graph, config, force_gpu):
     """See session() for details."""
-    if context.executing_eagerly():
-      return None
-    else:
+    def prepare_config(config):
+      """Returns a config for sessions.
 
-      def prepare_config(config):
-        """Returns a config for sessions.
+      Args:
+        config: An optional config_pb2.ConfigProto to use to configure the
+          session.
 
-        Args:
-          config: An optional config_pb2.ConfigProto to use to configure the
-            session.
-        Returns:
-          A config_pb2.ConfigProto object.
-        """
-        if config is None:
-          config = config_pb2.ConfigProto()
-          config.allow_soft_placement = not force_gpu
-          config.gpu_options.per_process_gpu_memory_fraction = 0.3
-        elif force_gpu and config.allow_soft_placement:
-          config = config_pb2.ConfigProto().CopyFrom(config)
-          config.allow_soft_placement = False
-        # Don't perform optimizations for tests so we don't inadvertently run
-        # gpu ops on cpu
-        config.graph_options.optimizer_options.opt_level = -1
-        config.graph_options.rewrite_options.constant_folding = (
-            rewriter_config_pb2.RewriterConfig.OFF)
-        config.graph_options.rewrite_options.arithmetic_optimization = (
-            rewriter_config_pb2.RewriterConfig.OFF)
-        return config
+      Returns:
+        A config_pb2.ConfigProto object.
+      """
+      # TODO(b/114333779): Enforce allow_soft_placement=False when
+      # use_gpu=False. Currently many tests rely on the fact that any device
+      # will be used even when a specific device is supposed to be used.
+      allow_soft_placement = not force_gpu
+      if config is None:
+        config = config_pb2.ConfigProto()
+        config.allow_soft_placement = allow_soft_placement
+        config.gpu_options.per_process_gpu_memory_fraction = 0.3
+      elif not allow_soft_placement and config.allow_soft_placement:
+        config_copy = config_pb2.ConfigProto()
+        config_copy.CopyFrom(config)
+        config = config_copy
+        config.allow_soft_placement = False
+      # Don't perform optimizations for tests so we don't inadvertently run
+      # gpu ops on cpu
+      config.graph_options.optimizer_options.opt_level = -1
+      config.graph_options.rewrite_options.constant_folding = (
+          rewriter_config_pb2.RewriterConfig.OFF)
+      config.graph_options.rewrite_options.arithmetic_optimization = (
+          rewriter_config_pb2.RewriterConfig.OFF)
+      return config
 
-      return ErrorLoggingSession(graph=graph, config=prepare_config(config))
+    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
 
-  @contextlib.contextmanager
   def _get_cached_session(self,
                           graph=None,
                           config=None,
-                          use_gpu=False,
                           force_gpu=False,
                           crash_if_inconsistent_args=True):
     """See cached_session() for documentation."""
-    if context.executing_eagerly():
-      yield None
+    if self._cached_session is None:
+      sess = self._create_session(
+          graph=graph, config=config, force_gpu=force_gpu)
+      self._cached_session = sess
+      self._cached_graph = graph
+      self._cached_config = config
+      self._cached_force_gpu = force_gpu
+      return sess
     else:
-      if self._cached_session is None:
-        sess = self._create_session(
-            graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu)
-        self._cached_session = sess
-        self._cached_graph = graph
-        self._cached_config = config
-        self._cached_use_gpu = use_gpu
-        self._cached_force_gpu = force_gpu
-        with self._constrain_devices_and_set_default(
-            sess, use_gpu, force_gpu) as constrained_sess:
-          yield constrained_sess
-      else:
-        if crash_if_inconsistent_args and self._cached_graph is not graph:
-          raise ValueError("The graph used to get the cached session is "
-                           "different than the one that was used to create the "
-                           "session. Maybe create a new session with "
-                           "self.session()")
-        if crash_if_inconsistent_args and self._cached_config is not config:
-          raise ValueError("The config used to get the cached session is "
-                           "different than the one that was used to create the "
-                           "session. Maybe create a new session with "
-                           "self.session()")
-        if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu:
-          raise ValueError(
-              "The use_gpu value used to get the cached session is "
-              "different than the one that was used to create the "
-              "session. Maybe create a new session with "
-              "self.session()")
-        if crash_if_inconsistent_args and (self._cached_force_gpu is
-                                           not force_gpu):
-          raise ValueError(
-              "The force_gpu value used to get the cached session is "
-              "different than the one that was used to create the "
-              "session. Maybe create a new session with "
-              "self.session()")
-        # If you modify this logic, make sure to modify it in _create_session
-        # as well.
-        sess = self._cached_session
-        with self._constrain_devices_and_set_default(
-            sess, use_gpu, force_gpu) as constrained_sess:
-          yield constrained_sess
+      if crash_if_inconsistent_args and self._cached_graph is not graph:
+        raise ValueError("The graph used to get the cached session is "
+                         "different than the one that was used to create the "
+                         "session. Maybe create a new session with "
+                         "self.session()")
+      if crash_if_inconsistent_args and self._cached_config is not config:
+        raise ValueError("The config used to get the cached session is "
+                         "different than the one that was used to create the "
+                         "session. Maybe create a new session with "
+                         "self.session()")
+      if crash_if_inconsistent_args and (self._cached_force_gpu is
+                                         not force_gpu):
+        raise ValueError(
+            "The force_gpu value used to get the cached session is "
+            "different than the one that was used to create the "
+            "session. Maybe create a new session with "
+            "self.session()")
+      return self._cached_session
 
 
 @tf_export("test.create_local_cluster")
@@ -2023,3 +2015,42 @@
   with graph.as_default():
     importer.import_graph_def(graph_def)
   assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+  """Removes reference cycles in `func_graph` FuncGraph.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  the FuncGraph goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+  # Clearing captures using clear() leaves some cycles around.
+  while func_graph.captures:
+    func_graph.captures.popitem()
+  memory.dismantle_ordered_dict(func_graph.captures)
+  ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+  """Removes reference cycles in PolymorphicFunction `func`.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  PolymorphicFunction goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func: A `PolymorphicFunction` object to destroy. `func` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+  cache = func._function_cache  # pylint: disable=protected-access
+  for concrete_func in cache.values():
+    dismantle_func_graph(concrete_func.graph)
+  while cache:
+    cache.popitem()
+  memory.dismantle_ordered_dict(cache)
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index a0939f9..c4f8fa9 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -71,9 +71,6 @@
           with self.cached_session(graph=ops.Graph()) as sess2:
             pass
         with self.assertRaises(ValueError):
-          with self.cached_session(use_gpu=True) as sess2:
-            pass
-        with self.assertRaises(ValueError):
           with self.cached_session(force_gpu=True) as sess2:
             pass
     # We make sure that test_session will cache the session even after the
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 7246341..b521b14 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -337,11 +337,6 @@
     size = "large",
     srcs = ["layers/convolutional_test.py"],
     srcs_version = "PY2AND3",
-    tags = [
-        "manual",
-        "noasan",  # times out b/63678675
-        "notsan",
-    ],
     deps = [
         ":keras",
         "//tensorflow/python:client_testlib",
@@ -700,6 +695,20 @@
 )
 
 py_test(
+    name = "feature_columns_integration_test",
+    size = "small",
+    srcs = ["engine/feature_columns_integration_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/feature_column:feature_column_py",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
     name = "training_eager_test",
     size = "medium",
     srcs = ["engine/training_eager_test.py"],
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index b52ab7f..529b07d 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@
 # This dictionary holds a mapping {graph: learning_phase}.
 # A learning phase is a bool tensor used to run Keras models in
 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+  pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
 
 # This boolean flag can be set to True to leave variable initialization
 # up to the user.
@@ -96,11 +105,11 @@
 
 # This dictionary holds a mapping between a graph and variables to initialize
 # in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
 
 # This dictionary holds a mapping between a graph and TF optimizers created in
 # the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
 
 
 @tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@
       Learning phase (scalar integer tensor or Python integer).
   """
   if context.executing_eagerly():
-    if 'eager' not in _GRAPH_LEARNING_PHASES:
+    if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
       # Fallback to inference mode as default.
       return 0
-    return _GRAPH_LEARNING_PHASES['eager']
+    return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
   graph = ops.get_default_graph()
   if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@
   if value not in {0, 1}:
     raise ValueError('Expected learning phase to be 0 or 1.')
   if context.executing_eagerly():
-    _GRAPH_LEARNING_PHASES['eager'] = value
+    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
   else:
     _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
 
@@ -415,7 +424,7 @@
   finally:
     # Restore learning phase to initial value.
     if context.executing_eagerly():
-      _GRAPH_LEARNING_PHASES['eager'] = previous_value
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
     else:
       _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
 
@@ -443,13 +452,7 @@
     session = default_session
   else:
     if _SESSION is None:
-      if not os.environ.get('OMP_NUM_THREADS'):
-        config = config_pb2.ConfigProto(allow_soft_placement=True)
-      else:
-        num_thread = int(os.environ.get('OMP_NUM_THREADS'))
-        config = config_pb2.ConfigProto(
-            intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
-      _SESSION = session_module.Session(config=config)
+      _SESSION = session_module.Session(config=get_default_session_config())
     session = _SESSION
   if not _MANUAL_VAR_INIT:
     with session.graph.as_default():
@@ -468,6 +471,16 @@
   _SESSION = session
 
 
+def get_default_session_config():
+  if not os.environ.get('OMP_NUM_THREADS'):
+    config = config_pb2.ConfigProto(allow_soft_placement=True)
+  else:
+    num_thread = int(os.environ.get('OMP_NUM_THREADS'))
+    config = config_pb2.ConfigProto(
+        intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
+  return config
+
+
 # DEVICE MANIPULATION
 
 
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 266af56..2f271c4 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -279,7 +279,7 @@
           keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
 
   def test_function_tf_run_options_with_run_metadata(self):
-    with self.test_session():
+    with self.cached_session():
       x_placeholder = keras.backend.placeholder(shape=())
       y_placeholder = keras.backend.placeholder(shape=())
 
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 7675a65..b6fae19 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -63,7 +63,7 @@
     if h5py is None:
       return  # Skip test if models cannot be saved.
 
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
 
       temp_dir = self.get_temp_dir()
@@ -226,7 +226,7 @@
           mode='unknown')
 
   def test_EarlyStopping(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(123)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -265,7 +265,7 @@
             verbose=0)
 
   def test_EarlyStopping_reuse(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       patience = 3
       data = np.random.random((100, 1))
@@ -287,7 +287,7 @@
       assert len(hist.epoch) >= patience
 
   def test_EarlyStopping_with_baseline(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       baseline = 0.5
       (data, labels), _ = testing_utils.get_test_data(
@@ -321,7 +321,7 @@
     monitor.on_epoch_end(0, logs={'loss': 0.})
 
   def test_LearningRateScheduler(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -368,7 +368,7 @@
               model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
 
   def test_ReduceLROnPlateau(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -470,7 +470,7 @@
     self.assertEqual(reduce_on_plateau.min_delta, 1e-13)
 
   def test_CSVLogger(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       temp_dir = self.get_temp_dir()
       self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
@@ -549,7 +549,7 @@
     tmpdir = self.get_temp_dir()
     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
 
-    with self.test_session():
+    with self.cached_session():
       fp = os.path.join(tmpdir, 'test.csv')
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -601,7 +601,7 @@
       assert 'nan' in values[-1], 'The last epoch was not logged.'
 
   def test_TerminateOnNaN(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -666,7 +666,7 @@
         i %= max_batch_index
 
     # case: Sequential
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.Dense(
@@ -743,7 +743,7 @@
     tmpdir = self.get_temp_dir()
     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
 
-    with self.test_session():
+    with self.cached_session():
       filepath = os.path.join(tmpdir, 'logs')
 
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -815,7 +815,7 @@
     tmpdir = self.get_temp_dir()
     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
 
-    with self.test_session():
+    with self.cached_session():
       filepath = os.path.join(tmpdir, 'logs')
 
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -925,7 +925,7 @@
     y_test = keras.utils.to_categorical(y_test)
     y_train = keras.utils.to_categorical(y_train)
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.Dense(
@@ -969,7 +969,7 @@
       while True:
         yield x, y
 
-    with self.test_session():
+    with self.cached_session():
       model = testing_utils.get_small_sequential_mlp(
           num_hidden=10, num_classes=10, input_dim=100)
       model.compile(
@@ -1011,7 +1011,7 @@
       os.name == 'nt',
       'use_multiprocessing=True does not work on windows properly.')
   def test_LambdaCallback(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
@@ -1055,7 +1055,7 @@
       assert not t.is_alive()
 
   def test_TensorBoard_with_ReduceLROnPlateau(self):
-    with self.test_session():
+    with self.cached_session():
       temp_dir = self.get_temp_dir()
       self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
 
@@ -1194,7 +1194,7 @@
   def test_RemoteMonitorWithJsonPayload(self):
     if requests is None:
       self.skipTest('`requests` required to run this test')
-    with self.test_session():
+    with self.cached_session():
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=TRAIN_SAMPLES,
           test_samples=TEST_SAMPLES,
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index b6b05c0..cb19a41 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -1001,7 +1001,7 @@
       self.build(input_shape)
 
       with context.graph_mode():
-        graph = eager_function.CapturingGraph()
+        graph = eager_function.FuncGraph('graph')
         with graph.as_default():
           if isinstance(input_shape, list):
             inputs = [generate_placeholders_from_shape(shape)
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index fcb0733..b28df75 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,8 +17,10 @@
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import backend
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import distribute as distribute_lib
@@ -46,7 +48,7 @@
       assign_ops.append(distribution_strategy.unwrap(sw.assign(w)))
 
     weights = weights[num_param:]
-  backend.get_session().run(assign_ops)
+  K.get_session().run(assign_ops)
 
 
 def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs,
@@ -211,7 +213,10 @@
   # validate the input and targets.
   x_values_list = validate_per_device_inputs(distribution_strategy, x)
 
-  y_values_list = validate_per_device_inputs(distribution_strategy, y)
+  if y is not None:
+    y_values_list = validate_per_device_inputs(distribution_strategy, y)
+  else:
+    y_values_list = None
 
   # Return the unwrapped values to avoid calling `unwrap` a second time.
   return x_values_list, y_values_list
@@ -269,3 +274,91 @@
     if x_shape != x_values[i].get_shape().as_list():
       raise ValueError('Input tensor shapes do not match for distributed tensor'
                        ' inputs {}'.format(x))
+
+
+def configure_and_create_session(distribution_strategy):
+  """Configure session config and create a session with it."""
+  # TODO(priyag): Throw error if a session already exists.
+  session_config = K.get_default_session_config()
+  distribution_strategy.configure(session_config)
+
+  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+    # TODO(priyag): Remove this workaround when Distributed Coordinator is
+    # integrated with keras and we can create a session from there.
+    master = distribution_strategy._tpu_cluster_resolver.master()  # pylint: disable=protected-access
+    session = session_module.Session(config=session_config, target=master)
+  else:
+    session = session_module.Session(config=session_config)
+
+  K.set_session(session)
+
+
+def validate_inputs(x, y):
+  """Validate inputs when using DistributionStrategy.
+
+  Args:
+    x: Model Inputs.
+    y: Model Targets.
+
+  Raises:
+    ValueError: if input is not a Dataset or a numpy array.
+  """
+  if isinstance(x, list) or isinstance(y, list):
+    raise ValueError('DistributionStrategy does not support lists of numpy '
+                     'arrays. You must pass a Dataset object or a numpy array '
+                     'as input.')
+
+  if isinstance(x, dict) or isinstance(y, dict):
+    raise ValueError('DistributionStrategy does not support inputs of type '
+                     'dict. You must pass a Dataset object or a numpy array as '
+                     'input.')
+
+  if isinstance(x, iterator_ops.Iterator) or \
+      isinstance(y, iterator_ops.Iterator):
+    raise ValueError('DistributionStrategy does not support inputs of type '
+                     'Iterator. You must pass a Dataset object or a numpy '
+                     'array as input.')
+
+
+def get_input_batch_params(first_x_value, batch_size, current_strategy):
+  """Calculate the number of batches and steps/steps_per_epoch.
+
+  Args:
+    first_x_value: This is the first input numpy array that is passed in as the
+      model input.
+    batch_size: The specified batch_size or the default batch_size of 32.
+    current_strategy: The current DistributionStrategy used to compile the
+      model.
+
+  Returns:
+    The steps or steps_per_epoch argument depending on if a user is
+    calling `fit`, `evaluate` or `predict`.
+
+  Raises:
+    ValueError: If the number of batches or steps evaluates to 0.
+
+  """
+  num_batches = first_x_value.shape[0] // batch_size
+  if not num_batches:
+    raise ValueError('Please specify a batch_size that is smaller than '
+                     'the number of input samples %d.' % first_x_value.shape[0])
+  # TODO(anjalisridhar): TPU currently supports using the num_towers property.
+  # We might want to look into implementing worker_devices. In multi worker
+  # strategy, perhaps num_towers works better?
+  steps = num_batches // current_strategy.num_towers
+  if not steps:
+    # TODO(anjalisridhar): Number of towers in the error message may not convey
+    # what we want to the user. Is there another terminology that we can use
+    # that is consistent across different strategies.
+    raise ValueError('The number of batches %d is smaller than the number '
+                     'of towers %d used for DistributionStrategy. ' %
+                     (num_batches, current_strategy.num_towers))
+  return steps
+
+
+def get_batch_dimension(iterator):
+  shapes = nest.flatten(iterator.output_shapes)
+  # Take the batch size from the first element, as it should be the same for
+  # all.
+  dims = shapes[0].dims
+  return dims[0] if dims else None
diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py
new file mode 100644
index 0000000..e0478ee
--- /dev/null
+++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py
@@ -0,0 +1,237 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests specific to Feature Columns integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
+
+
+class TestDNNModel(keras.models.Model):
+
+  def __init__(self, feature_columns, units, name=None, **kwargs):
+    super(TestDNNModel, self).__init__(name=name, **kwargs)
+    self._input_layer = fc.FeatureLayer(feature_columns, name='input_layer')
+    self._dense_layer = keras.layers.Dense(units, name='dense_layer')
+
+  def call(self, features):
+    net = self._input_layer(features)
+    net = self._dense_layer(net)
+    return net
+
+
+class FeatureColumnsIntegrationTest(test.TestCase):
+  """Most Sequential model API tests are covered in `training_test.py`.
+
+  """
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_sequential_model(self):
+    columns = [fc.numeric_column('a')]
+    model = keras.models.Sequential([
+        fc.FeatureLayer(columns),
+        keras.layers.Dense(64, activation='relu'),
+        keras.layers.Dense(20, activation='softmax')
+    ])
+    model.compile(
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        loss='categorical_crossentropy',
+        metrics=['accuracy'])
+
+    x = {'a': np.random.random((10, 1))}
+    y = np.random.randint(20, size=(10, 1))
+    y = keras.utils.to_categorical(y, num_classes=20)
+    model.fit(x, y, epochs=1, batch_size=5)
+    model.fit(x, y, epochs=1, batch_size=5)
+    model.evaluate(x, y, batch_size=5)
+    model.predict(x, batch_size=5)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_sequential_model_with_ds_input(self):
+    columns = [fc.numeric_column('a')]
+    model = keras.models.Sequential([
+        fc.FeatureLayer(columns),
+        keras.layers.Dense(64, activation='relu'),
+        keras.layers.Dense(20, activation='softmax')
+    ])
+    model.compile(
+        optimizer=rmsprop.RMSPropOptimizer(1e-3),
+        loss='categorical_crossentropy',
+        metrics=['accuracy'])
+
+    y = np.random.randint(20, size=(100, 1))
+    y = keras.utils.to_categorical(y, num_classes=20)
+    x = {'a': np.random.random((100, 1))}
+    ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+    ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+    ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+    model.fit(ds, steps_per_epoch=1)
+    model.fit(ds, steps_per_epoch=1)
+    model.evaluate(ds, steps=1)
+    model.predict(ds, steps=1)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_subclassed_model_with_feature_columns(self):
+    col_a = fc.numeric_column('a')
+    col_b = fc.numeric_column('b')
+
+    dnn_model = TestDNNModel([col_a, col_b], 20)
+
+    dnn_model.compile(
+        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+        loss='categorical_crossentropy',
+        metrics=['accuracy'])
+
+    x = {'a': np.random.random((10, 1)), 'b': np.random.random((10, 1))}
+    y = np.random.randint(20, size=(10, 1))
+    y = keras.utils.to_categorical(y, num_classes=20)
+    dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+    dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+    dnn_model.evaluate(x=x, y=y, batch_size=5)
+    dnn_model.predict(x=x, batch_size=5)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_subclassed_model_with_feature_columns_with_ds_input(self):
+    col_a = fc.numeric_column('a')
+    col_b = fc.numeric_column('b')
+
+    dnn_model = TestDNNModel([col_a, col_b], 20)
+
+    dnn_model.compile(
+        optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+        loss='categorical_crossentropy',
+        metrics=['accuracy'])
+
+    y = np.random.randint(20, size=(100, 1))
+    y = keras.utils.to_categorical(y, num_classes=20)
+    x = {'a': np.random.random((100, 1)), 'b': np.random.random((100, 1))}
+    ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+    ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+    ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+    dnn_model.fit(ds, steps_per_epoch=1)
+    dnn_model.fit(ds, steps_per_epoch=1)
+    dnn_model.evaluate(ds, steps=1)
+    dnn_model.predict(ds, steps=1)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def DISABLED_test_function_model_feature_layer_input(self):
+    col_a = fc.numeric_column('a')
+    col_b = fc.numeric_column('b')
+
+    feature_layer = fc.FeatureLayer([col_a, col_b], name='fc')
+    dense = keras.layers.Dense(4)
+
+    # This seems problematic.... We probably need something for FeatureLayer
+    # the way Input is for InputLayer.
+    output = dense(feature_layer)
+
+    model = keras.models.Model([feature_layer], [output])
+
+    optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+    loss = 'mse'
+    loss_weights = [1., 0.5]
+    model.compile(
+        optimizer,
+        loss,
+        metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+        loss_weights=loss_weights)
+
+    data = ({'a': np.arange(10), 'b': np.arange(10)}, np.arange(10, 20))
+    print(model.fit(*data, epochs=1))
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def DISABLED_test_function_model_multiple_feature_layer_inputs(self):
+    col_a = fc.numeric_column('a')
+    col_b = fc.numeric_column('b')
+    col_c = fc.numeric_column('c')
+
+    fc1 = fc.FeatureLayer([col_a, col_b], name='fc1')
+    fc2 = fc.FeatureLayer([col_b, col_c], name='fc2')
+    dense = keras.layers.Dense(4)
+
+    # This seems problematic.... We probably need something for FeatureLayer
+    # the way Input is for InputLayer.
+    output = dense(fc1) + dense(fc2)
+
+    model = keras.models.Model([fc1, fc2], [output])
+
+    optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+    loss = 'mse'
+    loss_weights = [1., 0.5]
+    model.compile(
+        optimizer,
+        loss,
+        metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+        loss_weights=loss_weights)
+
+    data_list = ([{
+        'a': np.arange(10),
+        'b': np.arange(10)
+    }, {
+        'b': np.arange(10),
+        'c': np.arange(10)
+    }], np.arange(10, 100))
+    print(model.fit(*data_list, epochs=1))
+
+    data_bloated_list = ([{
+        'a': np.arange(10),
+        'b': np.arange(10),
+        'c': np.arange(10)
+    }, {
+        'a': np.arange(10),
+        'b': np.arange(10),
+        'c': np.arange(10)
+    }], np.arange(10, 100))
+    print(model.fit(*data_bloated_list, epochs=1))
+
+    data_dict = ({
+        'fc1': {
+            'a': np.arange(10),
+            'b': np.arange(10)
+        },
+        'fc2': {
+            'b': np.arange(10),
+            'c': np.arange(10)
+        }
+    }, np.arange(10, 100))
+    print(model.fit(*data_dict, epochs=1))
+
+    data_bloated_dict = ({
+        'fc1': {
+            'a': np.arange(10),
+            'b': np.arange(10),
+            'c': np.arange(10)
+        },
+        'fc2': {
+            'a': np.arange(10),
+            'b': np.arange(10),
+            'c': np.arange(10)
+        }
+    }, np.arange(10, 100))
+    print(model.fit(*data_bloated_dict, epochs=1))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index cd74e36..5ef8d13 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -770,7 +770,7 @@
       # and graph building, the variables created after building the model in
       # a Graph are still valid when executing eagerly.
       with context.graph_mode():
-        graph = eager_function.CapturingGraph()
+        graph = eager_function.FuncGraph('graph')
         with graph.as_default():
           if isinstance(input_shape, list):
             x = [base_layer.generate_placeholders_from_shape(shape)
@@ -1355,7 +1355,9 @@
     ```
     """
     if not self._is_graph_network:
-      raise NotImplementedError
+      raise NotImplementedError(
+          'Currently `save` requires model to be a graph network. Consider '
+          'using `save_weights`, in order to save the weights of the model.')
 
     from tensorflow.python.keras.models import save_model  # pylint: disable=g-import-not-at-top
     save_model(self, filepath, overwrite, include_optimizer)
@@ -1574,7 +1576,10 @@
     def get_json_type(obj):
       # If obj is any numpy type
       if type(obj).__module__ == np.__name__:
-        return obj.item()
+        if isinstance(obj, np.ndarray):
+          return obj.tolist()
+        else:
+          return obj.item()
 
       # If obj is a python 'type'
       if type(obj).__name__ == type.__name__:
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index a2eed7c..a2f31fd 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -248,7 +248,7 @@
       loss = convert_custom_objects(training_config['loss'])
       metrics = convert_custom_objects(training_config['metrics'])
       weighted_metrics = convert_custom_objects(
-          training_config['weighted_metrics'])
+          training_config.get('weighted_metrics', None))
       sample_weight_mode = training_config['sample_weight_mode']
       loss_weights = training_config['loss_weights']
 
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 441f3f4..148dd23 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -48,7 +48,7 @@
 class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
 
   def test_weight_loading(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(2,))
       x = keras.layers.Dense(3)(a)
       b = keras.layers.Dense(1)(x)
@@ -208,7 +208,7 @@
       }))
   def test_preprocess_weights_for_loading_rnn_should_be_idempotent(
       self, layer_class, layer_args):
-    with self.test_session():
+    with self.cached_session():
       layer = layer_class(**layer_args)
       layer.build(input_shape=layer_args.get('input_shape'))
       weights1 = layer.get_weights()
@@ -232,7 +232,7 @@
     batch_size = 5
     num_classes = 2
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
       model.add(keras.layers.Dense(num_classes))
@@ -261,7 +261,7 @@
     num_hidden = 5
     input_dim = 3
     num_classes = 2
-    with self.test_session():
+    with self.cached_session():
       ref_model = keras.models.Sequential()
       ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
                                        name='d1'))
@@ -298,7 +298,7 @@
     num_hidden = 5
     input_dim = 3
     num_classes = 2
-    with self.test_session():
+    with self.cached_session():
       ref_model = keras.models.Sequential()
       ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
                                        name='d1'))
@@ -333,7 +333,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.RepeatVector(3))
@@ -378,7 +378,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.RepeatVector(3))
@@ -402,7 +402,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       # test with custom optimizer, loss
 
       class CustomOp(keras.optimizers.RMSprop):
@@ -438,7 +438,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.layers.Input(shape=(3,))
       x = keras.layers.Dense(2)(inputs)
       output = keras.layers.Dense(3)(x)
@@ -474,7 +474,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.Dense(3))
@@ -490,7 +490,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.Dense(3))
@@ -508,7 +508,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,)))
       model.add(keras.layers.Dense(3))
@@ -522,7 +522,7 @@
       os.remove(fname)
 
   def test_saving_lambda_numpy_array_arguments(self):
-    with self.test_session():
+    with self.cached_session():
       if h5py is None:
         self.skipTest('h5py required to run this test')
 
@@ -548,7 +548,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       # This layer name will make the `layers_name` HDF5 attribute blow
       # out of proportion. Note that it fits into the internal HDF5
       # attribute memory limit on its own but because h5py converts
@@ -589,7 +589,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       x = keras.Input(shape=(2,), name='nested_model_input')
       f = x
       for i in range(4):
@@ -634,7 +634,7 @@
     if h5py is None:
       self.skipTest('h5py required to run this test')
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input(shape=(3,))
       x = keras.layers.Dense(2)(inputs)
       outputs = keras.layers.Dense(3)(x)
@@ -703,7 +703,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_tensorflow_format_overwrite(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       model = SubclassedModel()
       temp_dir = self.get_temp_dir()
       prefix = os.path.join(temp_dir, 'ckpt')
@@ -760,7 +760,7 @@
         self.assertEqual(len(graph.get_operations()), op_count)
 
   def _weight_loading_test_template(self, make_model_fn):
-    with self.test_session():
+    with self.cached_session():
       model = make_model_fn()
       model.compile(
           loss='mse',
@@ -822,7 +822,7 @@
 
   def _new_layer_weight_loading_test_template(
       self, first_model_fn, second_model_fn, restore_init_fn):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       model = first_model_fn()
       temp_dir = self.get_temp_dir()
       prefix = os.path.join(temp_dir, 'ckpt')
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 28af8d6..9d615c9 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -132,7 +132,7 @@
 
   @parameterized.parameters((True,), (False,))
   def test_training_and_eval_methods_on_symbolic_tensors(self, deferred):
-    with self.test_session():
+    with self.cached_session():
 
       def get_model():
         if deferred:
@@ -222,7 +222,7 @@
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.BatchNormalization(input_shape=(4,)))
       assert model.updates
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 079c8da..061db8e 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -342,7 +342,7 @@
     self.assertListEqual(model.non_trainable_weights, weights)
 
   def test_learning_phase(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(32,), name='input_a')
       b = keras.layers.Input(shape=(32,), name='input_b')
 
@@ -458,7 +458,7 @@
     self.assertEqual(dense.get_output_mask_at(1), None)
 
   def test_multi_input_layer(self):
-    with self.test_session():
+    with self.cached_session():
       # test multi-input layer
       a = keras.layers.Input(shape=(32,), name='input_a')
       b = keras.layers.Input(shape=(32,), name='input_b')
@@ -530,7 +530,7 @@
       self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])
 
   def test_recursion(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(32,), name='input_a')
       b = keras.layers.Input(shape=(32,), name='input_b')
 
@@ -591,7 +591,7 @@
       self.assertListEqual([x.shape for x in fn_outputs], [(10, 7), (10, 64)])
 
   def test_multi_input_multi_output_recursion(self):
-    with self.test_session():
+    with self.cached_session():
       # test multi-input multi-output
       a = keras.layers.Input(shape=(32,), name='input_a')
       b = keras.layers.Input(shape=(32,), name='input_b')
@@ -816,7 +816,7 @@
     self.assertEqual(loss, 4.)
 
   def test_layer_sharing_at_heterogenous_depth(self):
-    with self.test_session():
+    with self.cached_session():
       x_val = np.random.random((10, 5))
 
       x = input_layer_lib.Input(shape=(5,))
@@ -837,7 +837,7 @@
       self.assertAllClose(output_val, output_val_2, atol=1e-6)
 
   def test_layer_sharing_at_heterogenous_depth_with_concat(self):
-    with self.test_session():
+    with self.cached_session():
       input_shape = (16, 9, 3)
       input_layer = input_layer_lib.Input(shape=input_shape)
 
@@ -864,7 +864,7 @@
       self.assertAllClose(output_val, output_val_2, atol=1e-6)
 
   def test_explicit_training_argument(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(2,))
       b = keras.layers.Dropout(0.5)(a)
       base_model = keras.models.Model(a, b)
@@ -887,7 +887,8 @@
 
   def test_multi_output_model_with_none_masking(self):
 
-    with self.test_session():
+    with self.cached_session():
+
       def func(x):
         return [x * 0.2, x * 0.3]
 
@@ -912,6 +913,23 @@
       assert out.shape == (4, 3, 2, 1)
       self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4)
 
+  def test_constant_initializer_with_numpy(self):
+
+    with self.cached_session():
+      initializer = keras.initializers.Constant(np.ones((3, 2)))
+      model = keras.models.Sequential()
+      model.add(keras.layers.Dense(2, input_shape=(3,),
+                                   kernel_initializer=initializer))
+      model.add(keras.layers.Dense(3))
+      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
+
+      json_str = model.to_json()
+      keras.models.model_from_json(json_str)
+
+      if yaml is not None:
+        yaml_str = model.to_yaml()
+        keras.models.model_from_yaml(yaml_str)
+
 
 class DeferredModeTest(test.TestCase):
 
@@ -1169,7 +1187,7 @@
 
   def testGetReachableFromInputs(self):
 
-    with self.test_session():
+    with self.cached_session():
       pl_1 = array_ops.placeholder(shape=None, dtype='float32')
       pl_2 = array_ops.placeholder(shape=None, dtype='float32')
       pl_3 = array_ops.placeholder(shape=None, dtype='float32')
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 85d2541..c674946 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,9 +20,11 @@
 
 import weakref
 import numpy as np
+import six
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
 from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -39,12 +41,14 @@
 from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.engine.network import Network
+from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import optimizer as tf_optimizer_module
 from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -405,20 +409,9 @@
     # Set DistributionStrategy specific parameters.
     self._distribution_strategy = distribute
     if self._distribution_strategy is not None:
-      self._grouped_model = self._compile_distributed_model(
+      self._grouped_model = None
+      distributed_training_utils.configure_and_create_session(
           self._distribution_strategy)
-      with self._distribution_strategy.scope():
-        first_replicated_model = self._distribution_strategy.unwrap(
-            self._grouped_model)[0]
-        # If the specified metrics in `compile` are stateful, raise an error
-        # since we currently don't support stateful metrics.
-        if first_replicated_model.stateful_metric_names:
-          raise NotImplementedError('Stateful metrics are not supported with '
-                                    'DistributionStrategy.')
-
-      # We initialize the callback model with the first replicated model.
-      self._replicated_model = DistributedCallbackModel(first_replicated_model)
-      self._replicated_model.set_original_model(self)
     if not self.built:
       # Model is not compilable because it does not know its number of inputs
       # and outputs, nor their shapes and names. We will compile after the first
@@ -636,6 +629,12 @@
         skip_target_indices=skip_target_indices,
         sample_weights=self.sample_weights)
 
+    # If using distribution strategy and stateful_metrics, raise an error
+    # since we currently don't support stateful metrics.
+    if self._distribution_strategy is not None and self.stateful_metric_names:
+      raise NotImplementedError('Stateful metrics are not supported with '
+                                'DistributionStrategy.')
+
     # Prepare gradient updates and state updates.
     self.total_loss = total_loss
 
@@ -652,19 +651,6 @@
     trainable_weights = self.trainable_weights
     self._collected_trainable_weights = trainable_weights
 
-  def _compile_distributed_model(self, distribution_strategy):
-    # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the
-    # model?
-    def _clone_model_per_tower(model):
-      new_model = training_distributed.clone_and_build_model(model)
-      return new_model
-
-    with distribution_strategy.scope():
-      # Create a copy of this model on each of the devices.
-      grouped_models = distribution_strategy.call_for_each_tower(
-          _clone_model_per_tower, self)
-    return grouped_models
-
   def _check_trainable_weights_consistency(self):
     """Check trainable weights count consistency.
 
@@ -771,9 +757,8 @@
     the model.
 
     Args:
-      x: Input data. A `tf.data` dataset.
-      y: Since `x` is a dataset, `y` should not be specified
-        (since targets will be obtained from the iterator).
+      x: Input data. A numpy array or `tf.data` dataset.
+      y: Target data. A numpy array or None if x is a `tf.data` dataset.
       sample_weight: An optional sample-weight array passed by the user to
         weight the importance of each sample in `x`.
       class_weight: An optional class-weight array by the user to
@@ -790,10 +775,7 @@
         Fraction of the training data to be used as validation data.
 
     Returns:
-      A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
-      If the model's input and targets are symbolic, these lists are empty
-      (since the model takes no user-provided data, instead the data comes
-      from the symbolic inputs/targets).
+      Iterator for reading the dataset `x`.
 
     Raises:
       ValueError: In case of invalid user-provided data.
@@ -806,12 +788,51 @@
       raise NotImplementedError('`class_weight` is currently not supported '
                                 'when using DistributionStrategy.')
 
+    # Validates `steps` argument right at the beginning since we use it to
+    # construct the dataset object.
+    # TODO(anjalisridhar): This may not be a valid error since we now accept
+    # numpy array inputs. We still want to assert that we have a populated steps
+    # parameter.
+    if check_steps:
+      if steps is None:
+        raise ValueError('When using DistributionStrategy, '
+                         'you should specify the `{steps_name}` argument.'
+                         .format(steps_name=steps_name))
+
+    first_x_value = nest.flatten(x)[0]
+    if isinstance(first_x_value, np.ndarray):
+      x_shape = first_x_value.shape
+      x_dtype = first_x_value.dtype
+      if batch_size is None:
+        batch_size = x_shape[0] // steps
+      if y is not None:
+        first_y_value = nest.flatten(y)[0]
+        x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
+                                   output_types=(x_dtype, first_y_value.dtype),
+                                   output_shapes=(x_shape[1:],
+                                                  first_y_value.shape[1:]))
+        # TODO(anjalisridhar): What should the buffer size be?
+        x = x.shuffle(10000)
+        x = x.repeat()
+        x = x.batch(batch_size)
+        y = None
+      else:
+        # This case is for the predict call where the dataset only contains
+        # inputs and no targets i.e it does not return a tuple.
+        # TODO(anjalisridhar): Raise an error if we are not able to process
+        # all the predict samples. This can happen if the number of batches is
+        # not evenly divisible by the number of worker devices.
+        x = Dataset.from_generator(lambda x=x: x,
+                                   output_types=x_dtype,
+                                   output_shapes=x_shape[1:])
+        x = x.repeat()
+        x = x.batch(batch_size)
+
     # TODO(anjalisridhar): Can we use the iterator and getnext op cache?
     # We require users to pass Datasets since we distribute the dataset across
     # multiple devices.
-    if not isinstance(x, dataset_ops.Dataset):
-      raise ValueError('When using DistributionStrategy, model inputs should be'
-                       ' Dataset instances; found instead %s.' % type(x))
+    assert isinstance(x, dataset_ops.Dataset)
+
     # TODO(anjalisridhar): We want distribute_dataset() to accept a Dataset or a
     # function which returns a Dataset. Currently distribute_dataset() only
     # accepts a function that returns a Dataset. Once we add support for being
@@ -819,39 +840,10 @@
     result = self._distribution_strategy.distribute_dataset(lambda: x)
     iterator = result.make_initializable_iterator()
     K.get_session().run(iterator.initializer)
-    # Validates `steps` argument based on x's type.
-    if check_steps:
-      if steps is None:
-        raise ValueError('When using a Dataset instance as input to a model, '
-                         'you should specify the `{steps_name}` argument.'
-                         .format(steps_name=steps_name))
 
     training_utils.validate_iterator_input(x, y, sample_weight,
                                            validation_split)
-    # x an y may be PerDevice objects with an input and output tensor
-    # corresponding to each device. For example, x could be
-    # PerDevice:{device: get_next tensor,...}.
-    next_element = iterator.get_next()
-
-    if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
-      raise ValueError('Please provide model inputs as a list or tuple of 2 '
-                       'elements: input and target pair. '
-                       'Received %s' % next_element)
-    x, y = next_element
-    # Validate that all the elements in x and y are of the same type and shape.
-    # We can then pass the first element of x and y to `_standardize_weights`
-    # below and be confident of the output. We need to reopen the scope since
-    # we unwrap values when we validate x and y.
-    with self._distribution_strategy.scope():
-      x_values, y_values = distributed_training_utils.\
-        validate_distributed_dataset_inputs(self._distribution_strategy, x, y)
-
-    _, _, sample_weights = self._standardize_weights(x_values,
-                                                     y_values,
-                                                     sample_weight,
-                                                     class_weight,
-                                                     batch_size)
-    return x, y, sample_weights
+    return iterator
 
   def _standardize_user_data(self,
                              x,
@@ -906,7 +898,8 @@
         Fraction of the training data to be used as validation data.
 
     Returns:
-      A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
+      A tuple of 3: inputs (arrays or dicts, depending on whether `x` was a dict
+      or not), target arrays, sample-weight arrays.
       If the model's input and targets are symbolic, these lists are empty
       (since the model takes no user-provided data, instead the data comes
       from the symbolic inputs/targets).
@@ -916,7 +909,7 @@
       RuntimeError: If the model was never compiled.
     """
     if self._distribution_strategy:
-      return self._distribution_standardize_user_data(
+      iterator = self._distribution_standardize_user_data(
           x,
           y,
           sample_weight=sample_weight,
@@ -926,6 +919,7 @@
           steps_name=steps_name,
           steps=steps,
           validation_split=validation_split)
+      return iterator, None, None
 
     if isinstance(x, dataset_ops.Dataset):
       if context.executing_eagerly():
@@ -971,17 +965,23 @@
                            'Make sure that your dataset can generate '
                            'required number of samples.')
 
-      if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
-        raise ValueError('Please provide model inputs as a list or tuple of 2 '
-                         'elements: input and target pair. '
-                         'Received %s' % next_element)
-      x, y = next_element
+      if (not isinstance(next_element, (list, tuple)) or
+          len(next_element) not in [2, 3]):
+        raise ValueError(
+            'Please provide model inputs as a list or tuple of 2 or 3 '
+            'elements: (input, target) or (input, target, sample_weights). '
+            'Received %s' % next_element)
+      if len(next_element) == 2:
+        x, y = next_element
+      else:
+        x, y, sample_weight = next_element
     x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
                                                      class_weight, batch_size)
     return x, y, sample_weights
 
   def _standardize_weights(self, x, y, sample_weight=None, class_weight=None,
                            batch_size=None,):
+    # TODO(sourabhbajaj): Split input validation from weight standardization.
     if sample_weight is not None and class_weight is not None:
       logging.warning(
           'Received both a `sample_weight` and `class_weight` argument. '
@@ -990,6 +990,7 @@
     all_inputs = []
     is_build_called = False
     is_compile_called = False
+    dict_inputs = False
     if not self.inputs:
       # We need to use `x` to set the model inputs.
       # We type-check that `x` and `y` are either single arrays
@@ -1001,7 +1002,9 @@
                            'array or a list of arrays. You passed: x=' + str(x))
         all_inputs += list(x)
       elif isinstance(x, dict):
-        raise ValueError('Please do not pass a dictionary as model inputs.')
+        dict_inputs = True
+        keys = sorted(x.keys())
+        all_inputs = [x[k] for k in keys]
       else:
         if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x):
           raise ValueError('Please provide as model inputs either a single '
@@ -1014,6 +1017,8 @@
       if not self.inputs:
         is_build_called = True
         self._set_inputs(x)
+    else:
+      dict_inputs = isinstance(self.inputs, dict)
 
     if y is not None:
       if not self.optimizer:
@@ -1166,6 +1171,10 @@
                          'a number of samples that can be '
                          'divided by the batch size. Found: ' +
                          str(x[0].shape[0]) + ' samples')
+
+    # If dictionary inputs were provided, we return a dictionary as well.
+    if dict_inputs:
+      x = dict(zip(feed_input_names, x))
     return x, y, sample_weights
 
   @checkpointable.no_automatic_dependency_tracking
@@ -1188,6 +1197,9 @@
       training: Boolean or None. Only relevant in symbolic mode. Specifies
         whether to build the model's graph in inference mode (False), training
         mode (True), or using the Keras learning phase (None).
+    Raises:
+      ValueError: If dict inputs are passed to a Sequential Model where the
+        first layer isn't FeatureLayer.
     """
     call_convention = getattr(
         self,
@@ -1204,6 +1216,14 @@
       if tensor_util.is_tensor(inputs):
         input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
         self.build(input_shape=input_shape)
+      elif isinstance(inputs, dict):
+        # We assert that the first layer is a FeatureLayer.
+        if not training_utils.is_feature_layer(self.layers[0]):
+          raise ValueError('Passing a dictionary input to a Sequential Model '
+                           'which doesn\'t have FeatureLayer as the first layer '
+                           'is an error')
+        input_shape = (None,)
+        self.build(input_shape=input_shape)
       else:
         input_shape = (None,) + inputs.shape[1:]
         self.build(input_shape=input_shape)
@@ -1231,36 +1251,22 @@
     assert context.executing_eagerly()
     if self.inputs:
       raise ValueError('Model inputs are already set.')
+
     # On-the-fly setting of model inputs/outputs as DeferredTensors,
     # to keep track of number of inputs and outputs and their ndim.
-    if isinstance(inputs, (list, tuple)):
-      if tensor_util.is_tensor(inputs[0]):
-        dummy_output_values = self.call(
-            training_utils.cast_if_floating_dtype(inputs))
-      else:
-        dummy_output_values = self.call(
-            [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
-      dummy_input_values = list(inputs)
-    else:
-      if tensor_util.is_tensor(inputs):
-        dummy_output_values = self.call(
-            training_utils.cast_if_floating_dtype(inputs))
-      else:
-        dummy_output_values = self.call(
-            ops.convert_to_tensor(inputs, dtype=K.floatx()))
-      dummy_input_values = [inputs]
-    if isinstance(dummy_output_values, (list, tuple)):
-      dummy_output_values = list(dummy_output_values)
-    else:
-      dummy_output_values = [dummy_output_values]
+    model_inputs = training_utils.ModelInputs(inputs)
+    dummy_input_values = model_inputs.get_input_values()
+    dummy_output_values = self.call(dummy_input_values)
+
+    self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+    self.input_names = model_inputs.get_input_names()
+
+    dummy_output_values = nest.flatten(dummy_output_values)
     self.outputs = [
-        base_layer.DeferredTensor(shape=(None for _ in v.shape),
-                                  dtype=v.dtype) for v in dummy_output_values]
-    self.inputs = [
-        base_layer.DeferredTensor(shape=(None for _ in v.shape),
-                                  dtype=v.dtype) for v in dummy_input_values]
-    self.input_names = [
-        'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
+        base_layer.DeferredTensor(shape=(None
+                                         for _ in v.shape), dtype=v.dtype)
+        for v in dummy_output_values
+    ]
     self.output_names = [
         'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
     self.built = True
@@ -1290,58 +1296,29 @@
 
     # On-the-fly setting of symbolic model inputs (either by using the tensor
     # provided, or by creating a placeholder if Numpy data was provided).
-    self.inputs = []
-    self.input_names = []
+    model_inputs = training_utils.ModelInputs(inputs)
+    dummy_input_values = model_inputs.get_symbolic_inputs()
+    self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+    self.input_names = model_inputs.get_input_names()
+
     self._feed_inputs = []
     self._feed_input_names = []
     self._feed_input_shapes = []
-    if isinstance(inputs, (list, tuple)):
-      inputs = list(inputs)
-    else:
-      inputs = [inputs]
 
-    for i, v in enumerate(inputs):
-      name = 'input_%d' % (i + 1)
-      self.input_names.append(name)
-      if isinstance(v, list):
-        v = np.asarray(v)
-        if v.ndim == 1:
-          v = np.expand_dims(v, 1)
-      if isinstance(v, (np.ndarray)):
-        # We fix the placeholder shape except the batch size.
-        # This is suboptimal, but it is the best we can do with the info
-        # we have. The user should call `model._set_inputs(placeholders)`
-        # to specify custom placeholders if the need arises.
-        shape = (None,) + v.shape[1:]
-        placeholder = K.placeholder(shape=shape, name=name)
-        self.inputs.append(placeholder)
-        self._feed_inputs.append(placeholder)
-        self._feed_input_names.append(name)
-        self._feed_input_shapes.append(shape)
-      else:
-        # Assumed tensor - TODO(fchollet) additional type check?
-        self.inputs.append(v)
-        if K.is_placeholder(v):
-          self._feed_inputs.append(v)
-          self._feed_input_names.append(name)
-          self._feed_input_shapes.append(K.int_shape(v))
+    for k, v in model_inputs.as_dict():
+      if K.is_placeholder(v):
+        self._feed_inputs.append(v)
+        self._feed_input_names.append(k)
+        self._feed_input_shapes.append(K.int_shape(v))
 
     if outputs is None:
       # Obtain symbolic outputs by calling the model.
-      if len(self.inputs) == 1:
-        if self._expects_training_arg:
-          outputs = self.call(self.inputs[0], training=training)
-        else:
-          outputs = self.call(self.inputs[0])
+      if self._expects_training_arg:
+        outputs = self.call(dummy_input_values, training=training)
       else:
-        if self._expects_training_arg:
-          outputs = self.call(self.inputs, training=training)
-        else:
-          outputs = self.call(self.inputs)
-    if isinstance(outputs, (list, tuple)):
-      outputs = list(outputs)
-    else:
-      outputs = [outputs]
+        outputs = self.call(dummy_input_values)
+
+    outputs = nest.flatten(outputs)
     self.outputs = outputs
     self.output_names = [
         'output_%d' % (i + 1) for i in range(len(self.outputs))]
@@ -1362,6 +1339,9 @@
           initial_epoch=0,
           steps_per_epoch=None,
           validation_steps=None,
+          max_queue_size=10,
+          workers=1,
+          use_multiprocessing=False,
           **kwargs):
     """Trains the model for a fixed number of epochs (iterations on a dataset).
 
@@ -1373,19 +1353,24 @@
             (in case the model has multiple inputs).
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
-          - A `tf.data` dataset or a dataset iterator.
+          - A `tf.data` dataset or a dataset iterator. Should return a tuple
+            of either `(inputs, targets)` or
+            `(inputs, targets, sample_weights)`.
+          - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
+            or `(inputs, targets, sample weights)`.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
-          tensor targets, or inversely). If `x` is a dataset or dataset
-          iterator, `y` should not be specified
-          (since targets will be obtained from the iterator).
+          tensor targets, or inversely). If `x` is a dataset, dataset
+          iterator, generator, or `keras.utils.Sequence` instance, `y` should
+          not be specified (since targets will be obtained from `x`).
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` if your data is in the
-            form of symbolic tensors, datasets, or dataset iterators
-            (since they generate batches).
+            form of symbolic tensors, dataset, dataset iterators,
+            generators, or `keras.utils.Sequence` instances (since they generate
+            batches).
         epochs: Integer. Number of epochs to train the model.
             An epoch is an iteration over the entire `x` and `y`
             data provided.
@@ -1407,7 +1392,8 @@
             on this data at the end of each epoch.
             The validation data is selected from the last samples
             in the `x` and `y` data provided, before shuffling. This argument is
-            not supported when `x` is a dataset or a dataset iterator.
+            not supported when `x` is a dataset, dataset iterator, generator or
+            `keras.utils.Sequence` instance.
         validation_data: Data on which to evaluate
             the loss and any model metrics at the end of each epoch.
             The model will not be trained on this data.
@@ -1438,7 +1424,9 @@
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             `sample_weight_mode="temporal"` in `compile()`. This argument is not
-            supported when `x` is a dataset or a dataset iterator.
+            supported when `x` is a dataset, dataset iterator, generator, or
+            `keras.utils.Sequence` instance, instead provide the sample_weights
+            as the third element of `x`.
         initial_epoch: Integer.
             Epoch at which to start training
             (useful for resuming a previous training run).
@@ -1452,6 +1440,20 @@
         validation_steps: Only relevant if `steps_per_epoch`
             is specified. Total number of steps (batches of samples)
             to validate before stopping.
+        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+            input only. Maximum size for the generator queue.
+            If unspecified, `max_queue_size` will default to 10.
+        workers: Integer. Used for generator or `keras.utils.Sequence` input
+            only. Maximum number of processes to spin up
+            when using process-based threading. If unspecified, `workers`
+            will default to 1. If 0, will execute the generator on the main
+            thread.
+        use_multiprocessing: Boolean. Used for generator or
+            `keras.utils.Sequence` input only. If `True`, use process-based
+            threading. If unspecified, `use_multiprocessing` will default to
+            `False`. Note that because this implementation relies on
+            multiprocessing, you should not pass non-picklable arguments to
+            the generator as they can't be passed easily to children processes.
         **kwargs: Used for backwards compatibility.
 
     Returns:
@@ -1468,6 +1470,23 @@
     # TODO(fchollet): this method may be creating reference cycles, which would
     # lead to accumulating garbage in memory when called in a loop. Investigate.
 
+    if data_utils.is_generator_or_sequence(x):
+      training_utils.check_generator_arguments(y, sample_weight)
+      return self.fit_generator(
+          x,
+          steps_per_epoch=steps_per_epoch,
+          epochs=epochs,
+          verbose=verbose,
+          callbacks=callbacks,
+          validation_data=validation_data,
+          validation_steps=validation_steps,
+          class_weight=class_weight,
+          max_queue_size=max_queue_size,
+          workers=workers,
+          use_multiprocessing=use_multiprocessing,
+          shuffle=shuffle,
+          initial_epoch=initial_epoch)
+
     # Backwards compatibility
     if batch_size is None and steps_per_epoch is None:
       batch_size = 32
@@ -1484,6 +1503,13 @@
     if self._distribution_strategy:
       distributed_training_utils.validate_callbacks(callbacks)
 
+      distributed_training_utils.validate_inputs(x, y)
+
+      first_x_value = nest.flatten(x)[0]
+      if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
+        steps_per_epoch = distributed_training_utils.get_input_batch_params(
+            first_x_value, batch_size, self._distribution_strategy)
+
     x, y, sample_weights = self._standardize_user_data(
         x,
         y,
@@ -1518,6 +1544,13 @@
             'However we received `validation_data=%s`' % validation_data)
 
       # Validate and standardize validation data.
+      if self._distribution_strategy:
+        distributed_training_utils.validate_inputs(val_x, val_y)
+        first_valx_value = nest.flatten(val_x)[0]
+        if not validation_steps and isinstance(first_valx_value, np.ndarray):
+          validation_steps = distributed_training_utils.get_input_batch_params(
+              first_valx_value, batch_size, self._distribution_strategy)
+
       val_x, val_y, val_sample_weights = self._standardize_user_data(
           val_x,
           val_y,
@@ -1566,12 +1599,11 @@
           validation_steps=validation_steps)
     elif self._distribution_strategy:
       return training_distributed.fit_loop(
-          self, x, y,
+          self, x,
           epochs=epochs,
           verbose=verbose,
           callbacks=callbacks,
-          val_inputs=val_x,
-          val_targets=val_y,
+          val_iterator=val_x,
           initial_epoch=initial_epoch,
           steps_per_epoch=steps_per_epoch,
           validation_steps=validation_steps)
@@ -1597,7 +1629,10 @@
                batch_size=None,
                verbose=1,
                sample_weight=None,
-               steps=None):
+               steps=None,
+               max_queue_size=10,
+               workers=1,
+               use_multiprocessing=False):
     """Returns the loss value & metrics values for the model in test mode.
 
     Computation is done in batches.
@@ -1611,18 +1646,21 @@
           - A dict mapping input names to the corresponding array/tensors,
             if the model has named inputs.
           - A `tf.data` dataset or a dataset iterator.
+          - A generator or `keras.utils.Sequence` instance.
         y: Target data. Like the input data `x`,
           it could be either Numpy array(s) or TensorFlow tensor(s).
           It should be consistent with `x` (you cannot have Numpy inputs and
           tensor targets, or inversely).
-          If `x` is a dataset or a dataset iterator, `y` should not be specified
-          (since targets will be obtained from the iterator/dataset).
+          If `x` is a dataset, dataset iterator, generator or
+          `keras.utils.Sequence` instance, `y` should not be specified (since
+          targets will be obtained from the iterator/dataset).
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` is your data is in the
-            form of symbolic tensors, datasets, or dataset iterators
-            (since they generate batches).
+            form of symbolic tensors, dataset, dataset iterators,
+            generators, or `keras.utils.Sequence` instances (since they generate
+            batches).
         verbose: 0 or 1. Verbosity mode.
             0 = silent, 1 = progress bar.
         sample_weight: Optional Numpy array of weights for
@@ -1636,11 +1674,25 @@
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             `sample_weight_mode="temporal"` in `compile()`. This argument is not
-            supported when `x` is a dataset or a dataset iterator.
+            supported when `x` is a dataset or a dataset iterator, instead pass
+            sample weights as the third element of `x`.
         steps: Integer or `None`.
             Total number of steps (batches of samples)
             before declaring the evaluation round finished.
             Ignored with the default value of `None`.
+        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+            input only. Maximum size for the generator queue.
+            If unspecified, `max_queue_size` will default to 10.
+        workers: Integer. Used for generator or `keras.utils.Sequence` input
+            only. Maximum number of processes to spin up when using
+            process-based threading. If unspecified, `workers` will default
+            to 1. If 0, will execute the generator on the main thread.
+        use_multiprocessing: Boolean. Used for generator or
+            `keras.utils.Sequence` input only. If `True`, use process-based
+            threading. If unspecified, `use_multiprocessing` will default to
+            `False`. Note that because this implementation relies on
+            multiprocessing, you should not pass non-picklable arguments to
+            the generator as they can't be passed easily to children processes.
 
     Returns:
         Scalar test loss (if the model has a single output and no metrics)
@@ -1651,11 +1703,28 @@
     Raises:
         ValueError: in case of invalid arguments.
     """
+    if data_utils.is_generator_or_sequence(x):
+      training_utils.check_generator_arguments(y, sample_weight)
+      return self.evaluate_generator(
+          x,
+          steps=steps,
+          verbose=verbose,
+          max_queue_size=max_queue_size,
+          workers=workers,
+          use_multiprocessing=use_multiprocessing)
+
     # Backwards compatibility.
     if batch_size is None and steps is None:
       batch_size = 32
 
     # Validate and standardize user data.
+    if self._distribution_strategy:
+      distributed_training_utils.validate_inputs(x, y)
+      first_x_value = nest.flatten(x)[0]
+      if isinstance(first_x_value, np.ndarray) and not steps:
+        steps = distributed_training_utils.get_input_batch_params(
+            first_x_value, batch_size, self._distribution_strategy)
+
     x, y, sample_weights = self._standardize_user_data(
         x,
         y,
@@ -1677,8 +1746,7 @@
     elif self._distribution_strategy:
       return training_distributed.test_loop(
           self,
-          inputs=x,
-          targets=y,
+          iterator=x,
           verbose=verbose,
           steps=steps)
     else:
@@ -1691,7 +1759,14 @@
           verbose=verbose,
           steps=steps)
 
-  def predict(self, x, batch_size=None, verbose=0, steps=None):
+  def predict(self,
+              x,
+              batch_size=None,
+              verbose=0,
+              steps=None,
+              max_queue_size=10,
+              workers=1,
+              use_multiprocessing=False):
     """Generates output predictions for the input samples.
 
     Computation is done in batches.
@@ -1703,16 +1778,32 @@
           - A TensorFlow tensor, or a list of tensors
             (in case the model has multiple inputs).
           - A `tf.data` dataset or a dataset iterator.
+          - A generator or `keras.utils.Sequence` instance.
         batch_size: Integer or `None`.
             Number of samples per gradient update.
             If unspecified, `batch_size` will default to 32.
             Do not specify the `batch_size` is your data is in the
-            form of symbolic tensors, dataset, or dataset iterators
-            (since they generate batches).
+            form of symbolic tensors, dataset, dataset iterators,
+            generators, or `keras.utils.Sequence` instances (since they generate
+            batches).
         verbose: Verbosity mode, 0 or 1.
         steps: Total number of steps (batches of samples)
             before declaring the prediction round finished.
             Ignored with the default value of `None`.
+        max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+            input only. Maximum size for the generator queue.
+            If unspecified, `max_queue_size` will default to 10.
+        workers: Integer. Used for generator or `keras.utils.Sequence` input
+            only. Maximum number of processes to spin up when using
+            process-based threading. If unspecified, `workers` will default
+            to 1. If 0, will execute the generator on the main thread.
+        use_multiprocessing: Boolean. Used for generator or
+            `keras.utils.Sequence` input only. If `True`, use process-based
+            threading. If unspecified, `use_multiprocessing` will default to
+            `False`. Note that because this implementation relies on
+            multiprocessing, you should not pass non-picklable arguments to
+            the generator as they can't be passed easily to children processes.
+
 
     Returns:
         Numpy array(s) of predictions.
@@ -1723,18 +1814,35 @@
             or in case a stateful model receives a number of samples
             that is not a multiple of the batch size.
     """
+    if data_utils.is_generator_or_sequence(x):
+      return self.predict_generator(
+          x,
+          steps=steps,
+          verbose=verbose,
+          max_queue_size=max_queue_size,
+          workers=workers,
+          use_multiprocessing=use_multiprocessing)
+
     # Backwards compatibility.
     if batch_size is None and steps is None:
       batch_size = 32
 
-    # Turn off prefetching since this is currently not deterministic. Once
-    # b/112498930 is fixed we can turn it back on.
-    # `_prefetch_on_device` is currently a property of only `MirroredStrategy`.
-    if (self._distribution_strategy and
-        hasattr(self._distribution_strategy, '_prefetch_on_device')):
-      self._distribution_strategy._prefetch_on_device = False  # pylint: disable=protected-access
+    if self._distribution_strategy:
+      # Turn off prefetching since this is currently not deterministic. Once
+      # b/112498930 is fixed we can turn it back on.
+      # `_prefetch_on_device` is currently a property of only
+      # `MirroredStrategy`.
+      if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+        self._distribution_strategy._prefetch_on_device = False  # pylint: disable=protected-access
+      distributed_training_utils.validate_inputs(x, None)
+      first_x_value = nest.flatten(x)[0]
+      if isinstance(first_x_value, np.ndarray) and not steps:
+        steps = distributed_training_utils.get_input_batch_params(
+            first_x_value, batch_size, self._distribution_strategy)
 
     # Validate and standardize user data.
+    # TODO(anjalisridhar): We don't pass batch_size here for some reason. This
+    # means that we end up calculating it twice which we should avoid.
     x, _, _ = self._standardize_user_data(
         x, check_steps=True, steps_name='steps', steps=steps)
 
@@ -2066,7 +2174,7 @@
     Arguments:
         generator: Generator yielding tuples (inputs, targets)
             or (inputs, targets, sample_weights)
-            or an instance of Sequence (keras.utils.Sequence)
+            or an instance of `keras.utils.Sequence`
             object in order to avoid duplicate data
             when using multiprocessing.
         steps: Total number of steps (batches of samples)
@@ -2130,9 +2238,8 @@
 
     Arguments:
         generator: Generator yielding batches of input samples
-            or an instance of Sequence (keras.utils.Sequence)
-            object in order to avoid duplicate data
-            when using multiprocessing.
+            or an instance of `keras.utils.Sequence` object in order to
+            avoid duplicate data when using multiprocessing.
         steps: Total number of steps (batches of samples)
             to yield from `generator` before stopping.
             Optional for `Sequence`: if unspecified, will use
@@ -2188,6 +2295,13 @@
       return self.callback_model
     return self
 
+  def _make_callback_model(self):
+    first_replicated_model = self._distribution_strategy.unwrap(
+        self._grouped_model)[0]
+    # We initialize the callback model with the first replicated model.
+    self._replicated_model = DistributedCallbackModel(first_replicated_model)
+    self._replicated_model.set_original_model(self)
+
 
 class DistributedCallbackModel(Model):
   """Model that is used for callbacks with DistributionStrategy."""
@@ -2225,6 +2339,6 @@
     # Whitelisted atttributes of the model that can be accessed by the user
     # during a callback.
     if item not in ['_setattr_tracking']:
-      logging.warning('You are accessing attribute ' + item + 'of the'
-                      'DistributedCallbackModel that may not have been set'
+      logging.warning('You are accessing attribute ' + item + ' of the '
+                      'DistributedCallbackModel that may not have been set '
                       'correctly.')
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e2c458c..95b864b 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -55,7 +55,7 @@
 
   Arguments:
       model: Keras Model instance.
-      inputs: List of input arrays.
+      inputs: Either a list of arrays or a dictionary.
       targets: List of target arrays.
       sample_weights: Optional list of sample weight arrays.
       batch_size: Integer batch size or None if unknown.
@@ -88,6 +88,7 @@
 
   sample_weights = sample_weights or []
   val_sample_weights = val_sample_weights or []
+  inputs = training_utils.ModelInputs(inputs).as_list()
   if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
     ins = inputs + targets + sample_weights + [1]
   else:
@@ -262,6 +263,7 @@
   model._make_predict_function()
   f = model.predict_function
 
+  inputs = training_utils.ModelInputs(inputs).as_list()
   if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
     ins = inputs + [0]
   else:
@@ -368,6 +370,7 @@
   f = model.test_function
 
   sample_weights = sample_weights or []
+  inputs = training_utils.ModelInputs(inputs).as_list()
   if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
     ins = inputs + targets + sample_weights + [0]
   else:
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 85f1d62..53291c3 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -19,38 +19,41 @@
 from __future__ import division
 from __future__ import print_function
 import numpy as np
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import callbacks as cbks
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.engine import distributed_training_utils
 from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
+
+
+# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
 
 
 def fit_loop(
     model,
-    inputs,
-    targets,
+    iterator,
     epochs=100,
     verbose=1,
     callbacks=None,
-    val_inputs=None,
-    val_targets=None,
+    val_iterator=None,
     initial_epoch=0,
     steps_per_epoch=None,
     validation_steps=None):
-  """fit function when using DistributionStrategy for training.
+  """Fit loop for training with DistributionStrategy.
 
   Arguments:
       model: Keras Model instance.
-      inputs: List of input arrays.
-      targets: List of target arrays.
+      iterator: Iterator for input data.
       epochs: Number of times to iterate over the data
-      verbose: Verbosity mode, 0, 1 or 2
+      verbose: Integer, Verbosity mode, 0, 1 or 2
       callbacks: List of callbacks to be called during training
-      val_inputs: List of input arrays.
-      val_targets: List of target arrays.
+      val_iterator: Iterator for validation data.
       initial_epoch: Epoch at which to start training
           (useful for resuming a previous training run)
       steps_per_epoch: Total number of steps (batches of samples)
@@ -67,6 +70,16 @@
       ValueError: in case of invalid arguments.
   """
   current_strategy = model._distribution_strategy
+
+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_fit_loop(
+        model, iterator, epochs, verbose, callbacks, initial_epoch,
+        steps_per_epoch)
+
+  if not model._grouped_model:
+    clone_model_on_towers(model, current_strategy, make_callback_model=True)
+
   def _per_device_train_function(model):
     model._make_train_function()
     return (model.train_function.inputs,
@@ -74,6 +87,7 @@
             model.train_function.updates_op,
             model.train_function.session_kwargs)
 
+  inputs, targets = _get_input_from_iterator(iterator, model)
   with current_strategy.scope():
     # Create train ops on each of the devices when we call
     # `_per_device_train_function`.
@@ -115,11 +129,6 @@
   do_validation = False
   if validation_steps:
     do_validation = True
-    if steps_per_epoch is None:
-      raise ValueError('Can only use `validation_steps` '
-                       'when doing step-wise '
-                       'training, i.e. `steps_per_epoch` '
-                       'must be set.')
 
   # Copy the weights from the original model to each of the replicated models.
   orig_model_weights = model.get_weights()
@@ -139,45 +148,46 @@
       verbose=verbose)
   out_labels = model.metrics_names or []
   callbacks.on_train_begin()
+
+  assert steps_per_epoch is not None
+
   for epoch in range(initial_epoch, epochs):
     callbacks.on_epoch_begin(epoch)
-    if steps_per_epoch is not None:
-      epoch_logs = {}
-      for step_index in range(steps_per_epoch):
-        batch_logs = {'batch': step_index, 'size': 1}
-        callbacks.on_batch_begin(step_index, batch_logs)
-        try:
-          outs = distributed_train_function(ins)
-        except errors.OutOfRangeError:
-          logging.warning('Your dataset iterator ran out of data; '
-                          'interrupting training. Make sure that your dataset '
-                          'can generate at least `steps_per_epoch * epochs` '
-                          'batches (in this case, %d batches).' %
-                          steps_per_epoch * epochs)
-          break
+    epoch_logs = {}
+    for step_index in range(steps_per_epoch):
+      batch_logs = {'batch': step_index, 'size': 1}
+      callbacks.on_batch_begin(step_index, batch_logs)
+      try:
+        outs = distributed_train_function(ins)
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
+                        'interrupting training. Make sure that your dataset '
+                        'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
 
-        if not isinstance(outs, list):
-          outs = [outs]
+      if not isinstance(outs, list):
+        outs = [outs]
 
-        outs = _aggregate_metrics_across_towers(
-            current_strategy.num_towers, out_labels, outs)
-        for l, o in zip(out_labels, outs):
-          batch_logs[l] = o
-        callbacks.on_batch_end(step_index, batch_logs)
-        if callbacks.model.stop_training:
-          break
-      if do_validation:
-        val_outs = test_loop(
-            model,
-            val_inputs,
-            val_targets,
-            steps=validation_steps,
-            verbose=0)
-        if not isinstance(val_outs, list):
-          val_outs = [val_outs]
-        # Same labels assumed.
-        for l, o in zip(out_labels, val_outs):
-          epoch_logs['val_' + l] = o
+      outs = _aggregate_metrics_across_towers(
+          current_strategy.num_towers, out_labels, outs)
+      for l, o in zip(out_labels, outs):
+        batch_logs[l] = o
+      callbacks.on_batch_end(step_index, batch_logs)
+      if callbacks.model.stop_training:
+        break
+    if do_validation:
+      val_outs = test_loop(
+          model,
+          val_iterator,
+          steps=validation_steps,
+          verbose=0)
+      if not isinstance(val_outs, list):
+        val_outs = [val_outs]
+      # Same labels assumed.
+      for l, o in zip(out_labels, val_outs):
+        epoch_logs['val_' + l] = o
 
     callbacks.on_epoch_end(epoch, epoch_logs)
     if callbacks.model.stop_training:
@@ -192,14 +202,178 @@
   return model.history
 
 
-def test_loop(model, inputs, targets, verbose=0, steps=None):
-  """evaluate method to validate a model that uses DistributionStrategy.
+def _experimental_fit_loop(
+    model,
+    iterator,
+    epochs=100,
+    verbose=1,
+    callbacks=None,
+    initial_epoch=0,
+    steps_per_epoch=None):
+  """Fit loop for training with TPU DistributionStrategy.
 
   Arguments:
       model: Keras Model instance.
-      inputs: List of input arrays.
-      targets: List of target arrays.
-      verbose: verbosity mode.
+      iterator: Iterator that returns inputs and targets
+      epochs: Number of times to iterate over the data
+      verbose: Integer, Verbosity mode, 0, 1 or 2
+      callbacks: List of callbacks to be called during training
+      initial_epoch: Epoch at which to start training
+          (useful for resuming a previous training run)
+      steps_per_epoch: Total number of steps (batches of samples)
+          before declaring one epoch finished and starting the
+          next epoch. Ignored with the default value of `None`.
+
+  Returns:
+      Returns `None`.
+
+  Raises:
+      ValueError: in case of invalid arguments.
+  """
+  current_strategy = model._distribution_strategy
+
+  # TODO(priyag): Add validation that shapes are fully defined for TPU case.
+
+  K.get_session().run(current_strategy.initialize())
+
+  def _per_device_train_function(model):
+    model._make_train_function()
+    return (model.train_function.inputs,
+            model.train_function.outputs,
+            model.train_function.updates_op,
+            model.train_function.session_kwargs)
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(1)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_train_function."""
+    # TODO(priyag, sourabhbajaj): The model gets cloned every time
+    # fit/test/predict is called. We should look into caching this keyed on
+    # input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=True,
+        inputs=inputs,
+        targets=targets)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_train_function, model._grouped_model)
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs,
+         grouped_updates, grouped_session_args)
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_train_function',
+        **all_session_args)
+
+    out_labels = model.metrics_names or []
+    for label, output in zip(out_labels, combined_fn.outputs):
+      if label == 'loss':
+        aggregation = distribute_lib.get_loss_reduction()
+      else:
+        # We aggregate all other metrics using mean for now. This is a
+        # temporary workaround until new metrics are in place.
+        aggregation = variable_scope.VariableAggregation.MEAN
+      ctx.set_last_step_output(label, output, aggregation)
+
+    # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn:
+    # feed_dict, session kwargs, run options, run_metadata for now. These should
+    # be handled appropriately
+    return combined_fn.updates_op
+
+  # Add initial dummy values for loss and other metric tensors.
+  initial_loop_values = {}
+  initial_loop_values['loss'] = constant_op.constant(1e7)
+  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
+    # steps_per_epoch and number of epochs.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=current_strategy.steps_per_run,
+        initial_loop_values=initial_loop_values)
+
+  train_op = ctx.run_op
+  output_tensors = ctx.last_step_outputs
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps_per_epoch is not None
+
+  # TODO(sourabhbajaj): Convert this into a proper validation function
+  if callbacks:
+    raise NotImplementedError(
+        'Callbacks are not supported with TPUStrategy right now.')
+
+  callbacks = cbks.configure_callbacks(
+      callbacks,
+      model,
+      do_validation=False,
+      val_inputs=None,
+      val_targets=None,
+      epochs=epochs,
+      steps_per_epoch=steps_per_epoch,
+      verbose=verbose)
+  # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
+  # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
+  # TODO(priyag, sourabhbajaj): Add validation.
+  callbacks.on_train_begin()
+  for epoch in range(initial_epoch, epochs):
+    callbacks.on_epoch_begin(epoch)
+    epoch_logs = {}
+    for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+      # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
+      # and batch_size
+      batch_logs = {'batch': step_index, 'size': 1}
+      callbacks.on_batch_begin(step_index, batch_logs)
+      try:
+        _, outputs = K.get_session().run([train_op, output_tensors])
+      except errors.OutOfRangeError:
+        logging.warning('Your dataset iterator ran out of data; '
+                        'interrupting training. Make sure that your dataset '
+                        'can generate at least `steps_per_epoch * epochs` '
+                        'batches (in this case, %d batches).' %
+                        (steps_per_epoch * epochs))
+        break
+
+      batch_logs.update(outputs)
+      callbacks.on_batch_end(step_index, batch_logs)
+      if callbacks.model.stop_training:
+        break
+
+    callbacks.on_epoch_end(epoch, epoch_logs)
+    if callbacks.model.stop_training:
+      break
+  callbacks.on_train_end()
+
+  # Copy the weights back from the replicated model to the original model.
+  with current_strategy.scope():
+    updated_weights = current_strategy.unwrap(
+        model._grouped_model)[0].get_weights()
+    model.set_weights(updated_weights)
+
+  K.get_session().run(current_strategy.finalize())
+  return model.history
+
+
+def test_loop(model, iterator, verbose=0, steps=None):
+  """Test loop for evaluating with DistributionStrategy.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator for input data.
+      verbose: Integer, Verbosity mode 0 or 1.
       steps: Total number of steps (batches of samples)
           before declaring predictions finished.
           Ignored with the default value of `None`.
@@ -208,9 +382,17 @@
       Scalar loss (if the model has a single output and no metrics)
       or list of scalars (if the model has multiple outputs
       and/or metrics). The attribute `model.metrics_names` will give you
-      the display labels for the scalar outputs.
+      the display labels for the outputs.
   """
   current_strategy = model._distribution_strategy
+
+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_test_loop(model, iterator, verbose, steps)
+
+  if not model._grouped_model:
+    clone_model_on_towers(model, current_strategy)
+
   def _per_device_test_function(model):
     model._make_test_function()
     return (model.test_function.inputs,
@@ -218,6 +400,7 @@
             model.test_function.updates_op,
             model.test_function.session_kwargs)
 
+  inputs, targets = _get_input_from_iterator(iterator, model)
   with current_strategy.scope():
     (grouped_inputs, grouped_outputs, grouped_updates,
      grouped_session_args) = current_strategy.call_for_each_tower(
@@ -259,38 +442,149 @@
     distributed_training_utils.set_weights(
         current_strategy, distributed_model, orig_model_weights)
 
-  if steps is not None:
-    for step in range(steps):
-      batch_outs = distributed_test_function(ins)
-      batch_outs = _aggregate_metrics_across_towers(
-          current_strategy.num_towers, model.metrics_names, batch_outs)
-      if isinstance(batch_outs, list):
-        if step == 0:
-          for _ in enumerate(batch_outs):
-            outs.append(0.)
-        for i, batch_out in enumerate(batch_outs):
-          outs[i] += batch_out
-      else:
-        if step == 0:
-          outs.append(0.)
-        outs[0] += batch_outs
-      if verbose == 1:
-        progbar.update(step + 1)
-    for i in range(len(outs)):
-      outs[i] /= steps
+  assert steps is not None
+  for step in range(steps):
+    batch_outs = distributed_test_function(ins)
+    batch_outs = _aggregate_metrics_across_towers(
+        current_strategy.num_towers, model.metrics_names, batch_outs)
+    if isinstance(batch_outs, list):
+      if step == 0:
+        outs = [0.] * len(batch_outs)
+      for i, batch_out in enumerate(batch_outs):
+        outs[i] += batch_out
+    else:
+      if step == 0:
+        outs.append(0.)
+      outs[0] += batch_outs
+    if verbose >= 1:
+      progbar.update(step + 1)
+  for i in range(len(outs)):
+    outs[i] /= steps
 
   if len(outs) == 1:
     return outs[0]
   return outs
 
 
-def predict_loop(model, inputs, verbose=0, steps=None):
-  """Abstract method to loop over some data in batches.
+def _experimental_test_loop(model, iterator, verbose=0, steps=None):
+  """Test loop for evaluating with TPU DistributionStrategy.
 
   Arguments:
       model: Keras Model instance.
-      inputs: list of tensors to be fed to `f`.
-      verbose: verbosity mode.
+      iterator: Iterator for input data.
+      verbose: Integer, Verbosity mode 0 or 1.
+      steps: Total number of steps (batches of samples)
+          before declaring predictions finished.
+          Ignored with the default value of `None`.
+
+  Returns:
+      Scalar loss (if the model has a single output and no metrics)
+      or list of scalars (if the model has multiple outputs
+      and/or metrics). The attribute `model.metrics_names` will give you
+      the display labels for the outputs.
+  """
+  current_strategy = model._distribution_strategy
+  K.get_session().run(current_strategy.initialize())
+
+  def _per_device_test_function(model):
+    model._make_test_function()
+    return (model.test_function.inputs,
+            model.test_function.outputs,
+            model.test_function.updates_op,
+            model.test_function.session_kwargs)
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(0)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_test_function."""
+    # TODO(priyag, sourabhbajaj): The model gets cloned every time
+    # fit/test/predict is called. We should look into caching this keyed on
+    # input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=False,
+        inputs=inputs,
+        targets=targets)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_test_function, model._grouped_model)
+
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+         grouped_session_args)
+
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_test_function',
+        **all_session_args)
+
+    for label, output in zip(model.metrics_names, combined_fn.outputs):
+      if label == 'loss':
+        aggregation = distribute_lib.get_loss_reduction()
+      else:
+        # We aggregate all other metrics using mean for now. This is a
+        # temporary workaround until new metrics are in place.
+        aggregation = variable_scope.VariableAggregation.MEAN
+      ctx.set_last_step_output(label, output, aggregation)
+
+    return combined_fn.updates_op
+
+  # Add initial dummy values for loss and other metric tensors.
+  initial_loop_values = {}
+  initial_loop_values['loss'] = constant_op.constant(1e7)
+  for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag): Use steps_per_run when we use new metrics as they will
+    # allow handling metric computation at each step using variables.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=1,
+        initial_loop_values=initial_loop_values)
+
+  test_op = ctx.run_op
+  output_tensors = ctx.last_step_outputs
+
+  if verbose == 1:
+    progbar = Progbar(target=steps)
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps is not None
+  outs = [0.] * len(model.metrics_names)
+  for step in range(steps):
+    _, batch_outs = K.get_session().run([test_op, output_tensors])
+    for i, label in enumerate(model.metrics_names):
+      outs[i] += batch_outs[label]
+    if verbose >= 1:
+      progbar.update(step + 1)
+  for i in range(len(outs)):
+    outs[i] /= steps
+
+  K.get_session().run(current_strategy.finalize())
+
+  if len(outs) == 1:
+    return outs[0]
+  return outs
+
+
+def predict_loop(model, iterator, verbose=0, steps=None):
+  """Predict loop for predicting with DistributionStrategy.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator for input data.
+      verbose: Integer, Verbosity mode 0 or 1.
       steps: Total number of steps (batches of samples)
           before declaring `_predict_loop` finished.
           Ignored with the default value of `None`.
@@ -301,6 +595,14 @@
       (if the model has multiple outputs).
   """
   current_strategy = model._distribution_strategy
+
+  # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
+  if current_strategy.__class__.__name__ == 'TPUStrategy':
+    return _experimental_predict_loop(model, iterator, verbose, steps)
+
+  if not model._grouped_model:
+    clone_model_on_towers(model, current_strategy)
+
   def _per_device_predict_function(model):
     model._make_predict_function()
     return (model.predict_function.inputs,
@@ -308,6 +610,7 @@
             model.predict_function.updates_op,
             model.predict_function.session_kwargs)
 
+  inputs, _ = _get_input_from_iterator(iterator, model)
   with current_strategy.scope():
     (grouped_inputs, grouped_outputs, grouped_updates,
      grouped_session_args) = current_strategy.call_for_each_tower(
@@ -354,9 +657,11 @@
       if step == 0:
         for _ in batch_outs:
           unconcatenated_outs.append([])
+      # TODO(anjalisridhar): Should combine the outputs from multiple towers
+      # correctly here.
       for i, batch_out in enumerate(batch_outs):
         unconcatenated_outs[i].append(batch_out)
-      if verbose == 1:
+      if verbose >= 1:
         progbar.update(step + 1)
     if len(unconcatenated_outs) == 1:
       return np.concatenate(unconcatenated_outs[0], axis=0)
@@ -366,12 +671,128 @@
     ]
 
 
-def clone_and_build_model(model):
+def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
+  """Predict loop for predicting with TPU DistributionStrategy.
+
+  Arguments:
+      model: Keras Model instance.
+      iterator: Iterator for input data.
+      verbose: Integer, Verbosity mode 0 or 1.
+      steps: Total number of steps (batches of samples)
+          before declaring `_predict_loop` finished.
+          Ignored with the default value of `None`.
+
+  Returns:
+      Array of predictions (if the model has a single output)
+      or list of arrays of predictions
+      (if the model has multiple outputs).
+  """
+  current_strategy = model._distribution_strategy
+  K.get_session().run(current_strategy.initialize())
+
+  # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here.
+  K.set_learning_phase(0)
+
+  def _per_device_predict_function(model):
+    model._make_predict_function()
+    return (model.predict_function.inputs,
+            model.predict_function.outputs,
+            model.predict_function.updates_op,
+            model.predict_function.session_kwargs)
+
+  def step_fn(ctx, inputs, targets):
+    """Clones the model and calls make_predict_function."""
+
+    # TODO(anjalisridhar): Support predict input correctly as it will not
+    # contain targets, only inputs.
+    del targets
+
+    # TODO(priyag, sourabhbajaj): The model gets cloned every time
+    # fit/test/predict is called. We should look into caching this keyed on
+    # input shapes.
+    clone_model_on_towers(
+        model,
+        current_strategy,
+        make_callback_model=False,
+        inputs=inputs)
+
+    (grouped_inputs, grouped_outputs, grouped_updates,
+     grouped_session_args) = current_strategy.call_for_each_tower(
+         _per_device_predict_function, model._grouped_model)
+
+    (all_inputs, all_outputs, all_updates,
+     all_session_args) = distributed_training_utils.unwrap_values(
+         current_strategy, grouped_inputs, grouped_outputs, grouped_updates,
+         grouped_session_args)
+
+    combined_fn = K.Function(
+        all_inputs, all_outputs,
+        updates=all_updates,
+        name='distributed_predict_function',
+        **all_session_args)
+
+    for label, output in zip(model.output_names, combined_fn.outputs):
+      ctx.set_last_step_output(label, output)
+
+    return combined_fn.updates_op
+
+  # Add initial dummy values for outputs.
+  initial_loop_values = {}
+  batch_dimension = distributed_training_utils.get_batch_dimension(iterator)
+  for name, tensor in zip(model.output_names, model.outputs):
+    # TODO(priyag): This is a workaround as we do not know the batch dimension
+    # of the model's output at this point.
+    tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
+    initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+
+  with current_strategy.scope():
+    # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
+    ctx = current_strategy.run_steps_on_dataset(
+        step_fn, iterator, iterations=1,
+        initial_loop_values=initial_loop_values)
+
+  predict_op = ctx.run_op
+  output_tensors = ctx.last_step_outputs
+
+  if verbose == 1:
+    progbar = Progbar(target=steps)
+
+  # Copy the weights from the original model to each of the replicated models.
+  orig_model_weights = model.get_weights()
+  with current_strategy.scope():
+    distributed_model = current_strategy.unwrap(model._grouped_model)[0]
+    distributed_training_utils.set_weights(
+        current_strategy, distributed_model, orig_model_weights)
+
+  assert steps is not None
+  # Since we do not know how many samples we will see, we cannot pre-allocate
+  # the returned Numpy arrays. Instead, we store one array per batch seen
+  # and concatenate them upon returning.
+  unconcatenated_outs = [[] for _ in model.outputs]
+  for step in range(steps):
+    _, batch_outs = K.get_session().run([predict_op, output_tensors])
+    # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
+    for i, label in enumerate(model.output_names):
+      unconcatenated_outs[i].extend(batch_outs[label])
+    if verbose >= 1:
+      progbar.update(step + 1)
+
+  K.get_session().run(current_strategy.finalize())
+
+  if len(unconcatenated_outs) == 1:
+    return np.concatenate(unconcatenated_outs[0], axis=0)
+  return [
+      np.concatenate(unconcatenated_outs[i], axis=0)
+      for i in range(len(unconcatenated_outs))
+  ]
+
+
+def _clone_and_build_model(model, inputs=None, targets=None):
   """Clone and build the given keras_model."""
   # We need to set the import here since we run into a circular dependency
   # error.
   from tensorflow.python.keras import models  # pylint: disable=g-import-not-at-top
-  cloned_model = models.clone_model(model, input_tensors=None)
+  cloned_model = models.clone_model(model, input_tensors=inputs)
 
   # Compile and build model.
   if isinstance(model.optimizer, optimizers.TFOptimizer):
@@ -380,16 +801,32 @@
     optimizer_config = model.optimizer.get_config()
     optimizer = model.optimizer.__class__.from_config(optimizer_config)
 
+  # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
+  # single tensor should be OK but it throws an error in that case.
+  if (targets is not None and not isinstance(targets, list) and
+      not isinstance(targets, dict)):
+    targets = [targets]
   cloned_model.compile(
       optimizer,
       model.loss,
       metrics=model.metrics,
       loss_weights=model.loss_weights,
       sample_weight_mode=model.sample_weight_mode,
-      weighted_metrics=model.weighted_metrics)
+      weighted_metrics=model.weighted_metrics,
+      target_tensors=targets)
   return cloned_model
 
 
+def clone_model_on_towers(
+    model, strategy, make_callback_model=False, inputs=None, targets=None):
+  """Create a cloned model on each tower."""
+  with strategy.scope():
+    model._grouped_model = strategy.call_for_each_tower(
+        _clone_and_build_model, model, inputs, targets)
+  if make_callback_model:
+    model._make_callback_model()
+
+
 def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
   """Aggregate metrics values across all towers.
 
@@ -419,3 +856,23 @@
     merged_output.append(m)
     current_index += num_devices
   return merged_output
+
+
+def _get_input_from_iterator(iterator, model):
+  """Get elements from the iterator and verify the input shape and type."""
+  next_element = iterator.get_next()
+
+  if isinstance(next_element, tuple):
+    x, y = next_element
+  else:
+    x = next_element
+    y = None
+  # Validate that all the elements in x and y are of the same type and shape.
+  # We can then pass the first element of x and y to `_standardize_weights`
+  # below and be confident of the output.
+  x_values, y_values = distributed_training_utils.\
+    validate_distributed_dataset_inputs(model._distribution_strategy, x, y)
+  # TODO(sourabhbajaj): Add support for sample weights in distribution
+  # strategy.
+  model._standardize_weights(x_values, y_values)
+  return x, y
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 1e37714..939a7f2 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -67,7 +67,8 @@
 
   Arguments:
       model: The model on which metrics are being calculated.
-      inputs: List of input arrays.
+      inputs: Either a dictionary of inputs to the model or a list of input
+        arrays.
       targets: List of target arrays.
       sample_weights: Optional list of sample weight arrays.
       training: Whether the model should be run in inference or training mode.
@@ -82,7 +83,7 @@
   kwargs = {}
   if model._expects_training_arg:
     kwargs['training'] = training
-  if len(inputs) == 1:
+  if len(inputs) == 1 and not isinstance(inputs, dict):
     inputs = inputs[0]
 
   if model._compute_output_and_mask_jointly:
@@ -369,6 +370,8 @@
     # Get current step size.
     if isinstance(x, list):
       step_size = x[0].get_shape().as_list()[0]
+    elif isinstance(x, dict):
+      step_size = list(x.values())[0].get_shape().as_list()[0]
     else:
       step_size = x.get_shape().as_list()[0]
 
@@ -417,11 +420,12 @@
   """
   assert isinstance(inputs, iterator_ops.EagerIterator)
   if not isinstance(inputs.output_shapes,
-                    (list, tuple)) or len(inputs.output_shapes) > 2:
+                    (list, tuple)) or len(inputs.output_shapes) > 3:
     raise ValueError(
-        'Please provide data as a list or tuple of 1 or 2 elements '
-        ' - input or input and target pair. Received %s. We do not use the '
-        '`target` value here.' % inputs.output_shapes)
+        'Please provide data as a list or tuple of 1, 2, or 3 elements '
+        ' - `(input)`, or `(input, target)`, or `(input, target,'
+        'sample_weights)`. Received %s. We do not use the `target` or'
+        '`sample_weights` value here.' % inputs.output_shapes)
   outs = []
   if verbose == 1:
     progbar = generic_utils.Progbar(target=steps)
@@ -444,10 +448,13 @@
     x, _, _ = model._standardize_user_data(x)
     x = training_utils.cast_if_floating_dtype(x)
 
+    if isinstance(x, list) and len(x) == 1:
+      x = x[0]
+
     if model._expects_training_arg:
-      batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
+      batch_outs = model.call(x, training=False)
     else:
-      batch_outs = model.call(x[0] if len(x) == 1 else x)
+      batch_outs = model.call(x)
     if not isinstance(batch_outs, list):
       batch_outs = [batch_outs]
 
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index bf5c7fd..3801300 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -366,7 +366,7 @@
     if scipy_sparse is None:
       return
 
-    with self.test_session():
+    with self.cached_session():
       test_inputs = [
           scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)
       ]
@@ -389,7 +389,7 @@
       model.evaluate(test_inputs, test_outputs, batch_size=2)
 
   def test_compile_with_sparse_placeholders(self):
-    with self.test_session():
+    with self.cached_session():
       input_layer = keras.layers.Input(shape=(10,), sparse=True)
       weights = variables_lib.Variable(
           np.ones((10, 1)).astype(np.float32), name='weights')
@@ -405,7 +405,7 @@
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))
 
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(4,))
       layer = keras.layers.BatchNormalization(input_shape=(4,))
       b = layer(a)
@@ -441,7 +441,7 @@
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_compile_warning_for_loss_missing_output(self):
-    with self.test_session():
+    with self.cached_session():
       inp = keras.layers.Input(shape=(16,), name='input_a')
       out_1 = keras.layers.Dense(8, name='dense_1')(inp)
       out_2 = keras.layers.Dense(3, activation='softmax', name='dense_2')(out_1)
@@ -481,8 +481,8 @@
         num_hidden=10, num_classes=num_classes, input_dim=input_dim)
     model.compile(
         loss='categorical_crossentropy',
-        metrics=['acc'],
-        weighted_metrics=['mae'],
+        metrics=['acc', metrics_module.CategoricalAccuracy()],
+        weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
         optimizer=RMSPropOptimizer(learning_rate=learning_rate))
 
     np.random.seed(1337)
@@ -536,6 +536,25 @@
     self.assertLess(score[0], ref_score[0])
 
   @tf_test_util.run_in_graph_and_eager_modes
+  def test_sequential_model_fails_with_dict_inputs(self):
+    num_classes = 5
+    model = testing_utils.get_small_sequential_mlp(
+        num_hidden=10, num_classes=num_classes)
+    model.compile(
+        RMSPropOptimizer(learning_rate=0.001),
+        metrics=['acc'],
+        weighted_metrics=['mae'],
+        loss='categorical_crossentropy')
+
+    x = {'dense_input': np.random.random((10, 1))}
+    y = np.random.randint(num_classes, size=(10, 1))
+
+    with self.assertRaisesRegexp(
+        ValueError, 'Passing a dictionary input to a Sequential Model which '
+        'doesnt have FeatureLayer as the first layer is an error'):
+      model.fit(x, y, batch_size=5, epochs=1)
+
+  @tf_test_util.run_in_graph_and_eager_modes
   def test_sample_weights(self):
     num_classes = 5
     batch_size = 5
@@ -550,8 +569,8 @@
         num_hidden=10, num_classes=num_classes, input_dim=input_dim)
     model.compile(
         RMSPropOptimizer(learning_rate=learning_rate),
-        metrics=['acc'],
-        weighted_metrics=['mae'],
+        metrics=['acc', metrics_module.CategoricalAccuracy()],
+        weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
         loss='categorical_crossentropy')
 
     np.random.seed(43)
@@ -635,7 +654,7 @@
     timesteps = 3
     learning_rate = 0.001
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.TimeDistributed(
@@ -679,8 +698,8 @@
       model.compile(
           RMSPropOptimizer(learning_rate=learning_rate),
           loss='binary_crossentropy',
-          metrics=['acc'],
-          weighted_metrics=['mae'],
+          metrics=['acc', metrics_module.CategoricalAccuracy()],
+          weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
           sample_weight_mode='temporal')
 
       model.fit(
@@ -722,7 +741,7 @@
     timesteps = 3
     learning_rate = 0.001
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.TimeDistributed(
@@ -791,7 +810,7 @@
     timesteps = 3
     learning_rate = 0.001
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.TimeDistributed(
@@ -835,7 +854,7 @@
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_masking_graph_sequential(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.array([[[1], [1]], [[0], [0]]])
       model = keras.models.Sequential()
       model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
@@ -849,7 +868,7 @@
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_masking_deferred_sequential(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.array([[[1], [1]], [[0], [0]]])
       model = keras.models.Sequential()
       model.add(keras.layers.Masking(mask_value=0))
@@ -863,7 +882,7 @@
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_masking_functional(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.array([[[1], [1]], [[0], [0]]])
       inputs = keras.layers.Input((2, 1))
       outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -893,7 +912,7 @@
       def compute_output_shape(self, input_shape):
         return input_shape
 
-    with self.test_session():
+    with self.cached_session():
       x = np.random.random((5, 3))
       inputs = keras.layers.Input((3,))
       masked = keras.layers.Masking(mask_value=0)(inputs)
@@ -905,7 +924,7 @@
       model.train_on_batch(x, y)
 
   def test_loss_masking(self):
-    with self.test_session():
+    with self.cached_session():
       weighted_loss = weighted_masked_objective(keras.losses.get('mae'))
       shape = (3, 4, 2)
       x = np.arange(24).reshape(shape)
@@ -926,12 +945,12 @@
 class LearningPhaseTest(test.TestCase):
 
   def test_empty_model_no_learning_phase(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       self.assertFalse(model.uses_learning_phase)
 
   def test_dropout_has_learning_phase(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_dim=3))
       model.add(keras.layers.Dropout(0.5))
@@ -942,7 +961,7 @@
 class TestDynamicTrainability(test.TestCase):
 
   def test_trainable_warning(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.random.random((5, 3))
       y = np.random.random((5, 2))
 
@@ -955,7 +974,7 @@
       self.assertRaises(Warning)
 
   def test_trainable_argument(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.random.random((5, 3))
       y = np.random.random((5, 2))
 
@@ -978,7 +997,7 @@
       self.assertAllClose(out, out_2)
 
   def test_layer_trainability_switch(self):
-    with self.test_session():
+    with self.cached_session():
       # with constructor argument, in Sequential
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, trainable=False, input_dim=1))
@@ -1008,7 +1027,7 @@
       self.assertListEqual(model.trainable_weights, [])
 
   def test_model_trainability_switch(self):
-    with self.test_session():
+    with self.cached_session():
       # a non-trainable model has no trainable weights
       x = keras.layers.Input(shape=(1,))
       y = keras.layers.Dense(2)(x)
@@ -1023,7 +1042,7 @@
       self.assertListEqual(model.trainable_weights, [])
 
   def test_nested_model_trainability(self):
-    with self.test_session():
+    with self.cached_session():
       # a Sequential inside a Model
       inner_model = keras.models.Sequential()
       inner_model.add(keras.layers.Dense(2, input_dim=1))
@@ -1102,7 +1121,7 @@
         y = arr_labels[start: end]
         yield x, y
 
-    with self.test_session():
+    with self.cached_session():
       x = keras.Input((2,))
       y = keras.layers.Dense(1)(x)
       fn_model = keras.models.Model(x, y)
@@ -1188,7 +1207,7 @@
         w = arr_sample_weights[start: end]
         yield x, y, w
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(1, input_shape=(2,)))
       model.compile(
@@ -1225,7 +1244,7 @@
       while 1:
         yield 0
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(1, input_shape=(2,)))
       model.compile(loss='mse', optimizer='sgd')
@@ -1283,7 +1302,7 @@
         w = arr_sample_weights[start: end]
         yield x, y, w
 
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(1, input_shape=(2,)))
       model.compile(loss='mse', optimizer='sgd')
@@ -1303,6 +1322,57 @@
                         workers=0,
                         use_multiprocessing=False)
 
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_generator_input_to_fit_eval_predict(self):
+    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+    def custom_generator():
+      while True:
+        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+    inputs = keras.layers.Input(shape=(10,))
+    x = keras.layers.Dense(10, activation='relu')(inputs)
+    outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+    model = keras.Model(inputs, outputs)
+
+    model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+    model.fit(
+        custom_generator(),
+        steps_per_epoch=2,
+        validation_data=val_data,
+        epochs=2)
+    model.evaluate(custom_generator(), steps=2)
+    model.predict(custom_generator(), steps=2)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_sequence_input_to_fit_eval_predict(self):
+    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+    class CustomSequence(keras.utils.Sequence):
+
+      def __getitem__(self, idx):
+        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+      def __len__(self):
+        return 2
+
+    inputs = keras.layers.Input(shape=(10,))
+    x = keras.layers.Dense(10, activation='relu')(inputs)
+    outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+    model = keras.Model(inputs, outputs)
+
+    model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
+    model.evaluate(CustomSequence())
+    model.predict(CustomSequence())
+
+    with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
+      model.fit(CustomSequence(), y=np.ones([10, 1]))
+
+    with self.assertRaisesRegexp(ValueError,
+                                 '`sample_weight` argument is not supported'):
+      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
+
 
 class TestTrainingUtils(test.TestCase):
 
@@ -1341,7 +1411,7 @@
 class TestTrainingWithDataTensors(test.TestCase):
 
   def test_training_and_eval_methods_on_symbolic_tensors_single_io(self):
-    with self.test_session():
+    with self.cached_session():
       x = keras.layers.Input(shape=(3,), name='input')
       y = keras.layers.Dense(4, name='dense')(x)
       model = keras.Model(x, y)
@@ -1381,7 +1451,7 @@
                 validation_data=(inputs, targets), validation_steps=2)
 
   def test_training_and_eval_methods_on_symbolic_tensors_multi_io(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.layers.Input(shape=(3,), name='input_a')
       b = keras.layers.Input(shape=(3,), name='input_b')
 
@@ -1482,7 +1552,7 @@
     by only passing them data for the placeholder inputs
     in the model.
     """
-    with self.test_session():
+    with self.cached_session():
       input_a_np = np.random.random((10, 3))
       input_b_np = np.random.random((10, 3))
 
@@ -1613,7 +1683,7 @@
       self.assertEqual(out.shape, (10 * 3, 4))
 
   def test_model_with_partial_loss(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.Input(shape=(3,), name='input_a')
       a_2 = keras.layers.Dense(4, name='dense_1')(a)
       dp = keras.layers.Dropout(0.5, name='dropout')
@@ -1654,7 +1724,7 @@
       _ = model.evaluate(input_a_np, [output_a_np])
 
   def test_model_with_external_loss(self):
-    with self.test_session():
+    with self.cached_session():
       # None loss, only regularization loss.
       a = keras.Input(shape=(3,), name='input_a')
       a_2 = keras.layers.Dense(4, name='dense_1',
@@ -1784,7 +1854,7 @@
       self.assertEqual(out[1].shape, (10 * 3, 4))
 
   def test_target_tensors(self):
-    with self.test_session():
+    with self.cached_session():
       # single-output, as list
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(4, input_shape=(4,), name='dense'))
@@ -1845,7 +1915,7 @@
                            sample_weight={'dense_a': np.random.random((10,))})
 
   def test_model_custom_target_tensors(self):
-    with self.test_session():
+    with self.cached_session():
       a = keras.Input(shape=(3,), name='input_a')
       b = keras.Input(shape=(3,), name='input_b')
 
@@ -2097,8 +2167,45 @@
                                  'you should specify the `steps` argument'):
       model.predict(dataset, verbose=0)
 
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_dataset_with_sample_weights(self):
+    model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+    optimizer = RMSPropOptimizer(learning_rate=0.001)
+    loss = 'mse'
+    metrics = ['mae', metrics_module.CategoricalAccuracy()]
+    model.compile(optimizer, loss, metrics=metrics)
+
+    inputs = np.zeros((10, 3), np.float32)
+    targets = np.zeros((10, 4), np.float32)
+    sample_weights = np.ones((10), np.float32)
+    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
+                                                      sample_weights))
+    dataset = dataset.repeat(100)
+    dataset = dataset.batch(10)
+
+    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+    model.evaluate(dataset, steps=2, verbose=1)
+    model.predict(dataset, steps=2)
+    model.train_on_batch(dataset)
+    model.predict_on_batch(dataset)
+
+  @tf_test_util.run_in_graph_and_eager_modes
+  def test_dataset_with_sparse_labels(self):
+    model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+    optimizer = RMSPropOptimizer(learning_rate=0.001)
+    loss = 'sparse_categorical_crossentropy'
+    model.compile(optimizer, loss)
+
+    inputs = np.zeros((10, 3))
+    targets = np.random.randint(0, 4, size=10, dtype=np.int32)
+    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+    dataset = dataset.repeat(100)
+    dataset = dataset.batch(10)
+
+    model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+
   def test_dataset_input_shape_validation(self):
-    with self.test_session():
+    with self.cached_session():
       model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
       model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
 
@@ -2108,8 +2215,10 @@
       dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
       dataset = dataset.repeat(100)
 
-      with self.assertRaisesRegexp(ValueError,
-                                   r'expected (.*?) to have 2 dimensions'):
+      with self.assertRaisesRegexp(
+          ValueError,
+          r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
+      ):
         model.train_on_batch(dataset)
 
       # Wrong input shape
@@ -2275,7 +2384,7 @@
 
   @tf_test_util.run_in_graph_and_eager_modes
   def test_metrics_masking(self):
-    with self.test_session():
+    with self.cached_session():
       np.random.seed(1337)
       model = keras.models.Sequential()
       model.add(keras.layers.Masking(mask_value=0, input_shape=(2, 1)))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index f94697c..8e9fab8 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -22,18 +22,22 @@
 import math
 
 import numpy as np
+import six
 
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import losses
 from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.util import nest
 
 
 def _map_nested(data, func):
@@ -210,10 +214,11 @@
 def standardize_single_array(x):
   if x is None:
     return None
-  elif tensor_util.is_tensor(x):
-    return x
-  elif x.ndim == 1:
-    x = np.expand_dims(x, 1)
+  if x.shape is not None and len(x.shape) == 1:
+    if tensor_util.is_tensor(x):
+      return array_ops.expand_dims(x, axis=1)
+    else:
+      return np.expand_dims(x, 1)
   return x
 
 
@@ -245,7 +250,8 @@
       ValueError: in case of improperly formatted user-provided data.
   """
   if not names:
-    if data is not None and hasattr(data, '__len__') and len(data):
+    if (data is not None and hasattr(data, '__len__') and len(data) and
+        not isinstance(data, dict)):
       raise ValueError('Error when checking model ' + exception_prefix + ': '
                        'expected no data, but got:', data)
     return []
@@ -341,7 +347,7 @@
   Raises:
       ValueError: In case of invalid user-provided argument.
   """
-  if x_weight is None or len(x_weight) == 0:  # pylint: disable=g-explicit-length-test
+  if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0):  # pylint: disable=g-explicit-length-test
     return [None for _ in output_names]
   if len(output_names) == 1:
     if isinstance(x_weight, list) and len(x_weight) == 1:
@@ -675,7 +681,8 @@
           'Expected sample_weight with rank '
           'less than or equal to ' + str(len(y.shape)))
 
-    if y.shape[:sample_weight.ndim] != sample_weight.shape:
+    if (not tensor_util.is_tensor(sample_weight) and
+        y.shape[:sample_weight.ndim] != sample_weight.shape):
       raise ValueError(
           'Found a sample_weight array with shape ' + str(sample_weight.shape) +
           ' for an input with shape ' + str(y.shape) + '. '
@@ -717,6 +724,8 @@
 def has_tensors(ls):
   if isinstance(ls, (list, tuple)):
     return any(tensor_util.is_tensor(v) for v in ls)
+  if isinstance(ls, dict):
+    return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
   return tensor_util.is_tensor(ls)
 
 
@@ -777,7 +786,9 @@
                      'Received: %s' % (x, y))
   if sample_weight is not None:
     raise ValueError('`sample_weight` argument is not supported when input '
-                     '`x` is a dataset or a dataset iterator. '
+                     '`x` is a dataset or a dataset iterator. Instead, you'
+                     'can provide sample_weight as the third element  of your'
+                     'dataset, i.e. (inputs, targets, sample_weight). '
                      'Received: x=%s, sample_weight=%s' % (x, sample_weight))
   if validation_split is not None and validation_split != 0.0:
     raise ValueError(
@@ -786,6 +797,18 @@
         'Received: x=%s, validation_split=%f' % (x, validation_split))
 
 
+def check_generator_arguments(y=None, sample_weight=None):
+  """Validates arguments passed when using a generator."""
+  if y is not None:
+    raise ValueError('`y` argument is not supported when data is'
+                     'a generator or Sequence instance. Instead pass targets'
+                     ' as the second element of the generator.')
+  if sample_weight is not None:
+    raise ValueError('`sample_weight` argument is not supported when data is'
+                     'a generator or Sequence instance. Instead pass sample'
+                     ' weights as the third element of the generator.')
+
+
 def check_steps_argument(input_data, steps, steps_name):
   """Validates `steps` argument based on input data's type.
 
@@ -825,6 +848,12 @@
   return False
 
 
+def cast_single_tensor(x):
+  if tensor_util.is_tensor(x) and x.dtype.is_floating:
+    return math_ops.cast(x, dtype=K.floatx())
+  return x
+
+
 def cast_if_floating_dtype(x):
   """Casts the given data tensors to the default floating point type.
 
@@ -842,13 +871,7 @@
     raise RuntimeError(
         'Please provide tensors for casting, got: {x}'.format(x=x))
 
-  if isinstance(x, (list, tuple)):
-    return [
-        math_ops.cast(val, dtype=K.floatx())
-        if tensor_util.is_tensor(val) and val.dtype.is_floating else val
-        for val in x
-    ]
-  return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
+  return nest.map_structure(cast_single_tensor, x)
 
 
 def get_output_sample_weight_and_mode(skip_target_weighing_indices,
@@ -929,3 +952,103 @@
       sample_weights.append(weight)
       sample_weight_modes.append(mode)
   return sample_weights, sample_weight_modes
+
+
+# TODO(rohanj): This is a hack to get around not depending on feature_column and
+# create a cyclical dependency. Figure out a cleaner solution
+def is_feature_layer(layer):
+  """Returns whether `layer` is a FeatureLayer or not."""
+  return getattr(layer, '_is_feature_layer', False)
+
+
+class ModelInputs(object):
+  """Encapsulates model inputs.
+
+  Allows for transforming model inputs while keeping the same structure.
+  """
+
+  def __init__(self, inputs):
+    self._inputs = inputs
+    self._is_dict = isinstance(self._inputs, dict)
+    self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
+    self._flattened_inputs = []
+    self._input_names = []
+    if isinstance(self._inputs, dict):
+      for k in sorted(self._inputs.keys()):
+        self._flattened_inputs.append(self._inputs[k])
+        self._input_names.append(k)
+    else:
+      self._flattened_inputs = nest.flatten(self._inputs)
+      self._input_names = [
+          'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
+      ]
+    assert len(self._input_names) == len(self._flattened_inputs)
+
+  def get_input_names(self):
+    """Returns keys to name inputs by.
+
+    In case inputs provided were a list, tuple or single entry, we make up a
+    key 'input_%d'. For dictionary case, we return a sorted list of keys.
+    """
+    return self._input_names
+
+  def _get(self, return_single_as_list=False):
+    """Returns provided inputs, potentially transformed.
+
+    Inputs are returned in the same format they were provided i.e. lists
+    are returned as lists, single entries as single entries (unless
+    `return_single_as_list` is true), dictionaries as dictionaries.
+
+    Args:
+      return_single_as_list: Returns a list of size 1 for single entry case.
+    """
+    if self._is_dict:
+      return dict(zip(self._input_names, self._flattened_inputs))
+    if self._is_single_input and not return_single_as_list:
+      return self._flattened_inputs[0]
+    return self._flattened_inputs
+
+  def get_input_values(self):
+    """Returns input values passed in."""
+    if context.executing_eagerly():
+      for i in range(len(self._flattened_inputs)):
+        v = self._flattened_inputs[i]
+        if tensor_util.is_tensor(v):
+          v = cast_single_tensor(v)
+        else:
+          v = ops.convert_to_tensor(v, dtype=K.floatx())
+        self._flattened_inputs[i] = v
+    return self._get(return_single_as_list=False)
+
+  def get_symbolic_inputs(self, return_single_as_list=False):
+    """Returns inputs to be set as self.inputs for a model."""
+    for i in range(len(self._flattened_inputs)):
+      k = self._input_names[i]
+      v = self._flattened_inputs[i]
+      if context.executing_eagerly():
+        v = base_layer.DeferredTensor(
+            shape=(None for _ in v.shape), dtype=v.dtype)
+      else:
+        if isinstance(v, list):
+          v = np.asarray(v)
+          if v.ndim == 1:
+            v = np.expand_dims(v, 1)
+        if isinstance(v, (np.ndarray)):
+          # We fix the placeholder shape except the batch size.
+          # This is suboptimal, but it is the best we can do with the info
+          # we have. The user should call `model._set_inputs(placeholders)`
+          # to specify custom placeholders if the need arises.
+          shape = (None,) + v.shape[1:]
+          v = K.placeholder(shape=shape, name=k)
+      self._flattened_inputs[i] = v
+
+    return self._get(return_single_as_list)
+
+  def as_dict(self):
+    """An iterable over a dictionary version of inputs."""
+    for i in range(len(self._flattened_inputs)):
+      yield self._input_names[i], self._flattened_inputs[i]
+
+  def as_list(self):
+    """Returning the inputs as a list."""
+    return self._flattened_inputs
diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py
index 297a1ae..e777cb6 100644
--- a/tensorflow/python/keras/engine/training_utils_test.py
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -20,8 +20,11 @@
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.platform import test
 
@@ -146,5 +149,91 @@
     self.assertEquals(any_true, False)
 
 
+class ModelInputsTest(test.TestCase):
+
+  def test_single_thing(self):
+    a = np.ones(10)
+    model_inputs = training_utils.ModelInputs(a)
+    self.assertEquals(['input_1'], model_inputs.get_input_names())
+    vals = model_inputs.get_input_values()
+    self.assertAllEqual(np.ones(10), vals)
+    self.assertFalse(tensor_util.is_tensor(vals))
+    vals = model_inputs.get_symbolic_inputs()
+    self.assertTrue(tensor_util.is_tensor(vals))
+    vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+    self.assertEquals(1, len(vals))
+    self.assertTrue(tensor_util.is_tensor(vals[0]))
+
+  def test_single_thing_eager(self):
+    with context.eager_mode():
+      a = np.ones(10)
+      model_inputs = training_utils.ModelInputs(a)
+      self.assertEquals(['input_1'], model_inputs.get_input_names())
+      vals = model_inputs.get_input_values()
+      self.assertAllEqual(np.ones(10), vals)
+      self.assertTrue(tensor_util.is_tensor(vals))
+      vals = model_inputs.get_symbolic_inputs()
+      self.assertTrue(isinstance(vals, base_layer.DeferredTensor))
+      vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+      self.assertEquals(1, len(vals))
+      self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+
+  def test_list(self):
+    a = [np.ones(10), np.ones(20)]
+    model_inputs = training_utils.ModelInputs(a)
+    self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+    vals = model_inputs.get_input_values()
+    self.assertEqual(2, len(vals))
+    self.assertAllEqual(np.ones(10), vals[0])
+    self.assertAllEqual(np.ones(20), vals[1])
+    self.assertFalse(tensor_util.is_tensor(vals[0]))
+    self.assertFalse(tensor_util.is_tensor(vals[1]))
+    vals = model_inputs.get_symbolic_inputs()
+    self.assertTrue(tensor_util.is_tensor(vals[0]))
+    self.assertTrue(tensor_util.is_tensor(vals[1]))
+
+  def test_list_eager(self):
+    with context.eager_mode():
+      a = [np.ones(10), np.ones(20)]
+      model_inputs = training_utils.ModelInputs(a)
+      self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+      vals = model_inputs.get_input_values()
+      self.assertEqual(2, len(vals))
+      self.assertAllEqual(np.ones(10), vals[0])
+      self.assertAllEqual(np.ones(20), vals[1])
+      self.assertTrue(tensor_util.is_tensor(vals[0]))
+      self.assertTrue(tensor_util.is_tensor(vals[1]))
+      vals = model_inputs.get_symbolic_inputs()
+      self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+      self.assertTrue(isinstance(vals[1], base_layer.DeferredTensor))
+
+  def test_dict(self):
+    a = {'b': np.ones(10), 'a': np.ones(20)}
+    model_inputs = training_utils.ModelInputs(a)
+    self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+    vals = model_inputs.get_input_values()
+    self.assertAllEqual(np.ones(20), vals['a'])
+    self.assertAllEqual(np.ones(10), vals['b'])
+    self.assertFalse(tensor_util.is_tensor(vals['a']))
+    self.assertFalse(tensor_util.is_tensor(vals['b']))
+    vals = model_inputs.get_symbolic_inputs()
+    self.assertTrue(tensor_util.is_tensor(vals['a']))
+    self.assertTrue(tensor_util.is_tensor(vals['b']))
+
+  def test_dict_eager(self):
+    with context.eager_mode():
+      a = {'b': np.ones(10), 'a': np.ones(20)}
+      model_inputs = training_utils.ModelInputs(a)
+      self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+      vals = model_inputs.get_input_values()
+      self.assertAllEqual(np.ones(20), vals['a'])
+      self.assertAllEqual(np.ones(10), vals['b'])
+      self.assertTrue(tensor_util.is_tensor(vals['a']))
+      self.assertTrue(tensor_util.is_tensor(vals['b']))
+      vals = model_inputs.get_symbolic_inputs()
+      self.assertTrue(isinstance(vals['a'], base_layer.DeferredTensor))
+      self.assertTrue(isinstance(vals['b'], base_layer.DeferredTensor))
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index a57ac12..d00def0 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -64,7 +64,7 @@
       specifying the stride length of the convolution.
       Specifying any stride value != 1 is incompatible with specifying
       any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
+    padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
     data_format: A string, one of `channels_last` (default) or `channels_first`.
       The ordering of the dimensions in the inputs.
       `channels_last` corresponds to inputs with shape
@@ -126,6 +126,10 @@
         kernel_size, rank, 'kernel_size')
     self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
     self.padding = conv_utils.normalize_padding(padding)
+    if (self.padding == 'causal' and not isinstance(self,
+                                                    (Conv1D, SeparableConv1D))):
+      raise ValueError('Causal padding is only supported for `Conv1D` '
+                       'and `SeparableConv1D`.')
     self.data_format = conv_utils.normalize_data_format(data_format)
     self.dilation_rate = conv_utils.normalize_tuple(
         dilation_rate, rank, 'dilation_rate')
@@ -172,12 +176,16 @@
       self.bias = None
     self.input_spec = InputSpec(ndim=self.rank + 2,
                                 axes={channel_axis: input_dim})
+    if self.padding == 'causal':
+      op_padding = 'valid'
+    else:
+      op_padding = self.padding
     self._convolution_op = nn_ops.Convolution(
         input_shape,
         filter_shape=self.kernel.get_shape(),
         dilation_rate=self.dilation_rate,
         strides=self.strides,
-        padding=self.padding.upper(),
+        padding=op_padding.upper(),
         data_format=conv_utils.convert_data_format(self.data_format,
                                                    self.rank + 2))
     self.built = True
@@ -264,6 +272,15 @@
     base_config = super(Conv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
 
+  def _compute_causal_padding(self):
+    """Calculates padding for 'causal' option for 1-d conv layers."""
+    left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
+    if self.data_format == 'channels_last':
+      causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+    else:
+      causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+    return causal_padding
+
 
 @tf_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
 class Conv1D(Conv):
@@ -361,6 +378,11 @@
         bias_constraint=constraints.get(bias_constraint),
         **kwargs)
 
+  def call(self, inputs):
+    if self.padding == 'causal':
+      inputs = array_ops.pad(inputs, self._compute_causal_padding())
+    return super(Conv1D, self).call(inputs)
+
 
 @tf_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
 class Conv2D(Conv):
@@ -1261,31 +1283,44 @@
 
   def get_config(self):
     config = {
-        'filters': self.filters,
-        'kernel_size': self.kernel_size,
-        'strides': self.strides,
-        'padding': self.padding,
-        'data_format': self.data_format,
-        'dilation_rate': self.dilation_rate,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
+        'filters':
+            self.filters,
+        'kernel_size':
+            self.kernel_size,
+        'strides':
+            self.strides,
+        'padding':
+            self.padding,
+        'data_format':
+            self.data_format,
+        'depth_multiplier':
+            self.depth_multiplier,
+        'dilation_rate':
+            self.dilation_rate,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
         'depthwise_initializer':
             initializers.serialize(self.depthwise_initializer),
         'pointwise_initializer':
             initializers.serialize(self.pointwise_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
         'depthwise_regularizer':
             regularizers.serialize(self.depthwise_regularizer),
         'pointwise_regularizer':
             regularizers.serialize(self.pointwise_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
         'activity_regularizer':
             regularizers.serialize(self.activity_regularizer),
         'depthwise_constraint':
             constraints.serialize(self.depthwise_constraint),
         'pointwise_constraint':
             constraints.serialize(self.pointwise_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint)
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint)
     }
     base_config = super(SeparableConv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -1311,7 +1346,7 @@
       of the convolution.
       Specifying any `stride` value != 1 is incompatible with specifying
       any `dilation_rate` value != 1.
-    padding: One of `"valid"` or `"same"` (case-insensitive).
+    padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
     data_format: A string, one of `channels_last` (default) or `channels_first`.
       The ordering of the dimensions in the inputs.
       `channels_last` corresponds to inputs with shape
@@ -1397,6 +1432,8 @@
         **kwargs)
 
   def call(self, inputs):
+    if self.padding == 'causal':
+      inputs = array_ops.pad(inputs, self._compute_causal_padding())
     if self.data_format == 'channels_last':
       strides = (1,) + self.strides * 2 + (1,)
       spatial_start_dim = 1
@@ -1411,12 +1448,16 @@
     pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
     dilation_rate = (1,) + self.dilation_rate
 
+    if self.padding == 'causal':
+      op_padding = 'valid'
+    else:
+      op_padding = self.padding
     outputs = nn.separable_conv2d(
         inputs,
         depthwise_kernel,
         pointwise_kernel,
         strides=strides,
-        padding=self.padding.upper(),
+        padding=op_padding.upper(),
         rate=dilation_rate,
         data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
 
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index f904744..2d3d38a 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -52,7 +52,7 @@
         'kernel_size': 3,
     }
 
-    self._run_test(kwargs, 'padding', ['valid', 'same'])
+    self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
     self._run_test(kwargs, 'strides', [2])
     self._run_test(kwargs, 'dilation_rate', [2])
 
@@ -329,7 +329,7 @@
         'kernel_size': 3,
     }
 
-    self._run_test(kwargs, 'padding', ['valid', 'same'])
+    self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
     self._run_test(kwargs, 'strides', [2])
     self._run_test(kwargs, 'dilation_rate', [2])
     self._run_test(kwargs, 'depth_multiplier', [2])
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index afef997..9988c9f 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -87,7 +87,7 @@
     embedding_dim = 4
     units = 2
     layer_class = keras.layers.GRU
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.Embedding(
@@ -146,7 +146,7 @@
   def test_regularizers_GRU(self):
     embedding_dim = 4
     layer_class = keras.layers.GRU
-    with self.test_session():
+    with self.cached_session():
       layer = layer_class(
           5,
           return_sequences=False,
@@ -166,7 +166,7 @@
   def test_constraints_GRU(self):
     embedding_dim = 4
     layer_class = keras.layers.GRU
-    with self.test_session():
+    with self.cached_session():
       k_constraint = keras.constraints.max_norm(0.01)
       r_constraint = keras.constraints.max_norm(0.01)
       b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@
   @tf_test_util.run_in_graph_and_eager_modes
   def test_with_masking_layer_GRU(self):
     layer_class = keras.layers.GRU
-    with self.test_session():
+    with self.cached_session():
       inputs = np.random.random((2, 3, 4))
       targets = np.abs(np.random.random((2, 3, 5)))
       targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 9802820..f536915 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -102,7 +102,7 @@
     embedding_dim = 4
     units = 2
     layer_class = keras.layers.LSTM
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.Embedding(
@@ -161,7 +161,7 @@
   def test_regularizers_LSTM(self):
     embedding_dim = 4
     layer_class = keras.layers.LSTM
-    with self.test_session():
+    with self.cached_session():
       layer = layer_class(
           5,
           return_sequences=False,
@@ -180,7 +180,7 @@
   def test_constraints_LSTM(self):
     embedding_dim = 4
     layer_class = keras.layers.LSTM
-    with self.test_session():
+    with self.cached_session():
       k_constraint = keras.constraints.max_norm(0.01)
       r_constraint = keras.constraints.max_norm(0.01)
       b_constraint = keras.constraints.max_norm(0.01)
@@ -200,7 +200,7 @@
   @tf_test_util.run_in_graph_and_eager_modes
   def test_with_masking_layer_LSTM(self):
     layer_class = keras.layers.LSTM
-    with self.test_session():
+    with self.cached_session():
       inputs = np.random.random((2, 3, 4))
       targets = np.abs(np.random.random((2, 3, 5)))
       targets /= targets.sum(axis=-1, keepdims=True)
@@ -225,7 +225,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       # Test with Keras tensor
       inputs = keras.Input((timesteps, embedding_dim))
       initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -252,7 +252,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       # Test with non-Keras tensor
       inputs = keras.Input((timesteps, embedding_dim))
       initial_state = [keras.backend.random_normal_variable(
@@ -275,7 +275,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       layer = keras.layers.LSTM(units, stateful=True)
       layer.build((num_samples, timesteps, embedding_dim))
       layer.reset_states()
@@ -306,7 +306,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input((timesteps, embedding_dim))
       _ = keras.layers.Masking()(inputs)
       initial_state = [keras.Input((units,)) for _ in range(num_states)]
@@ -329,7 +329,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
       layer = keras.layers.LSTM(units, return_state=True, stateful=True)
       outputs = layer(inputs)
@@ -347,7 +347,7 @@
     units = 3
     num_samples = 2
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
       layer = keras.layers.LSTM(units, return_state=True, return_sequences=True)
       outputs = layer(inputs)
@@ -366,7 +366,7 @@
     num_states = 2
     layer_class = keras.layers.LSTM
 
-    with self.test_session():
+    with self.cached_session():
       # Test with Keras tensor
       main_inputs = keras.Input((timesteps, embedding_dim))
       initial_state = [keras.Input((units,)) for _ in range(num_states)]
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index a3861e4..b9e9009 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -530,7 +530,9 @@
         y_np_2 = model.predict(x_np)
         self.assertAllClose(y_np, y_np_2, atol=1e-4)
 
-  def test_stacked_rnn_dropout(self):
+  def DISABLED_test_stacked_rnn_dropout(self):
+    # Temporarily disabled test due to an occasional Grappler segfault.
+    # See b/115523414
     cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
              keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
     layer = keras.layers.RNN(cells)
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 1429537..2f2295a 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -87,7 +87,7 @@
     embedding_dim = 4
     units = 2
     layer_class = keras.layers.SimpleRNN
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(
           keras.layers.Embedding(
@@ -146,7 +146,7 @@
   def test_regularizers_SimpleRNN(self):
     embedding_dim = 4
     layer_class = keras.layers.SimpleRNN
-    with self.test_session():
+    with self.cached_session():
       layer = layer_class(
           5,
           return_sequences=False,
@@ -166,7 +166,7 @@
   def test_constraints_SimpleRNN(self):
     embedding_dim = 4
     layer_class = keras.layers.SimpleRNN
-    with self.test_session():
+    with self.cached_session():
       k_constraint = keras.constraints.max_norm(0.01)
       r_constraint = keras.constraints.max_norm(0.01)
       b_constraint = keras.constraints.max_norm(0.01)
@@ -186,7 +186,7 @@
   @tf_test_util.run_in_graph_and_eager_modes
   def test_with_masking_layer_SimpleRNN(self):
     layer_class = keras.layers.SimpleRNN
-    with self.test_session():
+    with self.cached_session():
       inputs = np.random.random((2, 3, 4))
       targets = np.abs(np.random.random((2, 3, 5)))
       targets /= targets.sum(axis=-1, keepdims=True)
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 81c760b..473d8cd 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -22,7 +22,10 @@
 from abc import ABCMeta
 from abc import abstractmethod
 
+import functools
+import sys
 import types
+import weakref
 import six
 
 from tensorflow.python.eager import context
@@ -137,6 +140,21 @@
   return tf_decorator.make_decorator(result_fn, decorated)
 
 
+def weakmethod(method):
+  """Creates a weak reference to the bound method."""
+
+  cls = method.im_class
+  func = method.im_func
+  instance_ref = weakref.ref(method.im_self)
+
+  @functools.wraps(method)
+  def inner(*args, **kwargs):
+    return func.__get__(instance_ref(), cls)(*args, **kwargs)
+
+  del method
+  return inner
+
+
 def safe_div(numerator, denominator):
   """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
 
@@ -318,14 +336,27 @@
 
   def __new__(cls, *args, **kwargs):
     obj = super(Metric, cls).__new__(cls)
-    # TODO(psv): Fix reference cycle issue here.
 
-    # Converting update_state_fn() into a graph function, so that
-    # we can return a single op that performs all of the variable updates.
-    defuned_update_state_fn = function.defun(obj.update_state)
-    obj.update_state = types.MethodType(
-        update_state_wrapper(defuned_update_state_fn), obj)
-    obj.result = types.MethodType(result_wrapper(obj.result), obj)
+    if sys.version_info < (3,):
+      # Wrap methods in `weakmethod` function to remove binding and create a
+      # weak reference. This is to remove reference cycle that is created here.
+      # This is not an issue in Python 3.
+      if context.executing_eagerly():
+        update_state = weakmethod(obj.update_state)
+      else:
+        update_state = function.defun(obj.update_state)
+      obj.update_state = weakmethod(
+          types.MethodType(update_state_wrapper(update_state), obj))
+      result = weakmethod(obj.result)
+      obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
+    else:
+      # Converting update_state_fn() into a graph function, so that
+      # we can return a single op that performs all of the variable updates.
+      defuned_update_state_fn = function.defun(obj.update_state)
+      obj.update_state = types.MethodType(
+          update_state_wrapper(defuned_update_state_fn), obj)
+      obj.result = types.MethodType(result_wrapper(obj.result), obj)
+
     return obj
 
   def __call__(self, *args, **kwargs):
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 779c08c..4195ea1 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -212,7 +212,7 @@
       self.assertAllClose(
           val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
 
-  @test_util.run_in_graph_and_eager_modes
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def test_mean(self):
     m = metrics.Mean(name='my_mean')
 
@@ -394,7 +394,7 @@
     self.assertTrue(acc_obj.stateful)
     self.assertEqual(len(acc_obj.variables), 2)
     self.assertEqual(acc_obj.dtype, dtypes.float32)
-    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(variables.variables_initializer(acc_obj.variables))
 
     # verify that correct value is returned
     update_op = acc_obj.update_state([[0, 0, 1], [0, 1, 0]],
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 71c1987..3a1b000 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -463,7 +463,7 @@
     num_samples = 10
     input_dim = 50
 
-    with self.test_session():
+    with self.cached_session():
       model = SimpleTestModel(num_classes=num_classes,
                               use_dp=True,
                               use_bn=True)
@@ -481,7 +481,7 @@
     num_samples = 10
     input_dim = 50
 
-    with self.test_session():
+    with self.cached_session():
       model = MultiIOTestModel(num_classes=num_classes,
                                use_dp=True,
                                use_bn=True)
@@ -501,7 +501,7 @@
     num_samples = 10
     input_dim = 50
 
-    with self.test_session():
+    with self.cached_session():
       model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
       model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
 
@@ -521,7 +521,7 @@
     num_samples = 1000
     input_dim = 50
 
-    with self.test_session():
+    with self.cached_session():
       model = MultiIOTestModel(num_classes=num_classes,
                                use_dp=True,
                                use_bn=True)
@@ -610,7 +610,7 @@
       def call(self, x):
         return self.bn(self.fc(x))
 
-    with self.test_session():
+    with self.cached_session():
       model = TestModel1()
 
       x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -631,7 +631,7 @@
       def call(self, x):
         return self.bn(self.fc(x))
 
-    with self.test_session():
+    with self.cached_session():
       model = TestModel2()
 
       x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -655,7 +655,7 @@
       def call(self, x):
         return self.bn(self.fc(x))
 
-    with self.test_session():
+    with self.cached_session():
       model = TestModel3()
 
       x = array_ops.ones(shape=[100, 784], dtype='float32')
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index c3b7301..41c5e3c 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -414,10 +414,10 @@
       this argument must be set to `True` (default `False`). To restore the
       original model, use the function
       `in_place_subclassed_model_state_restoration(model)`.
-    optimizer_iterations: An iterations variable to pass to the optimizer if
-      the model uses a TFOptimizer, and if the clone is compiled. This is used
-      when a Keras model is cloned into an Estimator model function, because
-      Estimators create their own global step variable.
+    optimizer_iterations: An iterations variable that will be incremented by the
+      optimizer if the clone is compiled. This argument is used when a Keras
+      model is cloned into an Estimator model function, because Estimators
+      create their own global step variable.
 
   Returns:
     Clone of the model.
@@ -444,6 +444,8 @@
     clone = model
     _in_place_subclassed_model_reset(clone)
     if input_tensors is not None:
+      if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
+        input_tensors = input_tensors[0]
       clone._set_inputs(input_tensors)
 
   # Compile/Build model
@@ -458,6 +460,8 @@
     else:
       optimizer_config = model.optimizer.get_config()
       optimizer = model.optimizer.__class__.from_config(optimizer_config)
+      if optimizer_iterations is not None:
+        optimizer.iterations = optimizer_iterations
 
     clone.compile(
         optimizer,
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 1d0f56f..c550cae 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -25,7 +25,9 @@
 
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import metrics
 from tensorflow.python.keras import models
 from tensorflow.python.ops import random_ops
@@ -51,7 +53,7 @@
 class TestModelCloning(test.TestCase):
 
   def test_clone_sequential_model(self):
-    with self.test_session():
+    with self.cached_session():
       val_a = np.random.random((10, 4))
       val_out = np.random.random((10, 4))
 
@@ -64,7 +66,7 @@
     # Everything should work in a new session.
     keras.backend.clear_session()
 
-    with self.test_session():
+    with self.cached_session():
       # With placeholder creation
       new_model = keras.models.clone_model(model)
       # update ops from batch norm needs to be included
@@ -89,7 +91,7 @@
       new_model.train_on_batch(None, val_out)
 
   def test_clone_functional_model(self):
-    with self.test_session():
+    with self.cached_session():
       val_a = np.random.random((10, 4))
       val_b = np.random.random((10, 4))
       val_out = np.random.random((10, 4))
@@ -110,7 +112,7 @@
     # Everything should work in a new session.
     keras.backend.clear_session()
 
-    with self.test_session():
+    with self.cached_session():
       # With placeholder creation
       new_model = keras.models.clone_model(model)
       self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
@@ -137,7 +139,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_clone_functional_model_with_masking(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.array([[[1], [1]], [[0], [0]]])
       inputs = keras.Input((2, 1))
       outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -238,7 +240,7 @@
 class TestCloneAndBuildModel(test.TestCase):
 
   def test_clone_and_build_non_compiled_model(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.random((10, 4))
       out = np.random.random((10, 4))
 
@@ -251,7 +253,7 @@
     # Everything should work in a new session.
     keras.backend.clear_session()
 
-    with self.test_session():
+    with self.cached_session():
       # With placeholder creation
       new_model = models.clone_and_build_model(model, compile_clone=True)
       with self.assertRaisesRegexp(RuntimeError, 'must compile'):
@@ -289,7 +291,7 @@
     # Everything should work in a new session.
     keras.backend.clear_session()
 
-    with self.test_session():
+    with self.cached_session():
       # With placeholder creation
       new_model = models.clone_and_build_model(
           model, compile_clone=True, in_place_reset=is_subclassed)
@@ -316,7 +318,7 @@
       new_model.evaluate(inp, out)
 
   def test_clone_and_build_compiled_sequential_model(self):
-    with self.test_session():
+    with self.cached_session():
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(4, input_shape=(4,)))
       model.add(keras.layers.BatchNormalization())
@@ -328,7 +330,7 @@
     self._clone_and_build_test_helper(model)
 
   def test_clone_and_build_functional_model(self):
-    with self.test_session():
+    with self.cached_session():
       input_a = keras.Input(shape=(4,))
       dense_1 = keras.layers.Dense(4,)
       dense_2 = keras.layers.Dense(4,)
@@ -358,12 +360,42 @@
         out = self.layer2(out)
         return out
 
-    with self.test_session():
+    with self.cached_session():
       model = SubclassedModel()
       model.compile('rmsprop', 'mse',
                     metrics=['acc', metrics.categorical_accuracy])
     self._clone_and_build_test_helper(model, True)
 
+  def assert_optimizer_iterations_increases(self, optimizer):
+    with self.cached_session():
+      input_a = keras.Input(shape=(4,))
+      dense_1 = keras.layers.Dense(4,)
+      dense_2 = keras.layers.Dense(4,)
+
+      x_a = dense_1(input_a)
+      x_a = keras.layers.Dropout(0.5)(x_a)
+      x_a = keras.layers.BatchNormalization()(x_a)
+      x_a = dense_2(x_a)
+      model = keras.models.Model(input_a, x_a)
+      model.compile(optimizer, 'mse',
+                    metrics=['acc', metrics.categorical_accuracy])
+
+      global_step = keras.backend.variable(123, dtype=dtypes.int64)
+      clone_model = models.clone_and_build_model(
+          model, compile_clone=True, optimizer_iterations=global_step)
+
+      inp = np.random.random((10, 4))
+      out = np.random.random((10, 4))
+      clone_model.train_on_batch(inp, out)
+
+      self.assertEqual(K.eval(global_step), 124)
+
+  def test_replace_tf_optimizer_iterations_variable(self):
+    self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
+
+  def test_replace_keras_optimizer_iterations_variable(self):
+    self.assert_optimizer_iterations_increases('adam')
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 9a68fc0..8d74934 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -85,23 +85,23 @@
 class KerasOptimizersTest(test.TestCase):
 
   def test_sgd(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.SGD(lr=0.01,
                                            momentum=0.9,
                                            nesterov=True))
 
   def test_rmsprop(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.RMSprop())
       _test_optimizer(keras.optimizers.RMSprop(decay=1e-3))
 
   def test_adagrad(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.Adagrad())
       _test_optimizer(keras.optimizers.Adagrad(decay=1e-3))
 
   def test_adadelta(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.Adadelta(), target=0.6)
       # Accuracy seems dependent on the initialization. Even adding tf.Print
       # nodes in the graph seemed to affect the initialization seed, and hence
@@ -109,28 +109,28 @@
       _test_optimizer(keras.optimizers.Adadelta(decay=1e-3), target=0.4)
 
   def test_adam(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.Adam())
       _test_optimizer(keras.optimizers.Adam(decay=1e-3))
       _test_optimizer(keras.optimizers.Adam(amsgrad=True))
 
   def test_adamax(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.Adamax())
       _test_optimizer(keras.optimizers.Adamax(decay=1e-3))
 
   def test_nadam(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.Nadam())
 
   def test_clipnorm(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.SGD(lr=0.01,
                                            momentum=0.9,
                                            clipnorm=0.5))
 
   def test_clipvalue(self):
-    with self.test_session():
+    with self.cached_session():
       _test_optimizer(keras.optimizers.SGD(lr=0.01,
                                            momentum=0.9,
                                            clipvalue=0.5))
@@ -158,7 +158,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_tfoptimizer_iterations(self):
-    with self.test_session():
+    with self.cached_session():
       optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 58405c5..501b50b 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -29,7 +29,8 @@
 def get_test_data(train_samples,
                   test_samples,
                   input_shape,
-                  num_classes):
+                  num_classes,
+                  random_seed=None):
   """Generates test data to train a model on.
 
   Arguments:
@@ -37,10 +38,13 @@
     test_samples: Integer, how many test samples to generate.
     input_shape: Tuple of integers, shape of the inputs.
     num_classes: Integer, number of classes for the data and targets.
+    random_seed: Integer, random seed used by numpy to generate data.
 
   Returns:
     A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
   """
+  if random_seed is not None:
+    np.random.seed(random_seed)
   num_sample = train_samples + test_samples
   templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
   y = np.random.randint(0, num_classes, size=(num_sample,))
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 3a176c3..8ebca14 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -93,7 +93,7 @@
   Arguments:
       input_length: integer.
       filter_size: integer.
-      padding: one of "same", "valid", "full".
+      padding: one of "same", "valid", "full", "causal".
       stride: integer.
       dilation: dilation rate, integer.
 
@@ -102,9 +102,9 @@
   """
   if input_length is None:
     return None
-  assert padding in {'same', 'valid', 'full'}
+  assert padding in {'same', 'valid', 'full', 'causal'}
   dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
-  if padding == 'same':
+  if padding in ['same', 'causal']:
     output_length = input_length
   elif padding == 'valid':
     output_length = input_length - dilated_filter_size + 1
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index c1ee34a..b736daa 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -40,6 +40,7 @@
 from six.moves.urllib.request import urlopen
 
 from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -93,6 +94,11 @@
   from six.moves.urllib.request import urlretrieve
 
 
+def is_generator_or_sequence(x):
+  """Check if `x` is a Keras generator type."""
+  return tf_inspect.isgenerator(x) or isinstance(x, Sequence)
+
+
 def _extract_archive(file_path, path='.', archive_format='auto'):
   """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
 
@@ -494,6 +500,7 @@
     raise NotImplementedError
 
 
+@tf_export('keras.utils.OrderedEnqueuer')
 class OrderedEnqueuer(SequenceEnqueuer):
   """Builds a Enqueuer from a Sequence.
 
@@ -550,7 +557,7 @@
       self.executor_fn = lambda seqs: multiprocessing.Pool(  # pylint: disable=g-long-lambda
           workers, initializer=init_pool, initargs=(seqs,))
     else:
-       # We do not need the init since it's threads.
+      # We do not need the init since it's threads.
       self.executor_fn = lambda _: ThreadPool(workers)
     self.workers = workers
     self.queue = queue.Queue(max_queue_size)
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 1f28c59..158a9a5 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -26,6 +26,7 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
+@tf_export('keras.utils.get_source_inputs')
 def get_source_inputs(tensor, layer=None, node_index=None):
   """Returns the list of input tensors necessary to compute `tensor`.
 
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3026c77..6bba99b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -286,7 +286,10 @@
     srcs = ["decode_csv_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:errors",
+        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:parsing_ops",
     ],
 )
@@ -622,6 +625,7 @@
         "//tensorflow/python:linalg_ops",
         "//tensorflow/python:math_ops",
     ],
+    tags = ["notap"],
 )
 
 cuda_py_test(
@@ -779,6 +783,7 @@
     size = "small",
     srcs = ["regex_full_match_op_test.py"],
     additional_deps = [
+        "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
@@ -1009,6 +1014,7 @@
     size = "small",
     srcs = ["substr_op_test.py"],
     additional_deps = [
+        "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:errors",
@@ -1634,6 +1640,7 @@
     srcs = ["functional_ops_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework",
diff --git a/tensorflow/python/kernel_tests/accumulate_n_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py
index b793906..0bc5268 100644
--- a/tensorflow/python/kernel_tests/accumulate_n_test.py
+++ b/tensorflow/python/kernel_tests/accumulate_n_test.py
@@ -76,7 +76,7 @@
   # Putting them here so that everything that exercises AccumulateNV2 is in
   # one place and the default build runs all unit tests.
   def testSimple(self):
-    with self.test_session():
+    with self.cached_session():
       random_arrays = [
           np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
       ]
@@ -91,27 +91,27 @@
       self.assertAllClose(np_val, tf_val.eval())
 
   def testZeroArgs(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         tf_val = math_ops.accumulate_n([])
         tf_val.eval()
 
   def testWrongShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         a = variables.Variable(0.2)
         b = variables.Variable(0.1)
         math_ops.accumulate_n([a, b], shape=[2, 2])  # Should be shape=[]
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         a = variables.Variable(np.array([0.1, 0.2]))
         b = variables.Variable(np.array([[0.3], [0.4]]))
         math_ops.accumulate_n([a, b])
 
   def testWrongType(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         a = variables.Variable(0.2, dtype=np.float32)
         b = variables.Variable(0.1, dtype=np.float32)
@@ -119,7 +119,7 @@
 
   def testWrongTypeOneInput(self):
     # Scenario that used to trigger a bug, even when testWrongType() worked
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         a = variables.Variable(0.2, dtype=np.float32)
         math_ops.accumulate_n([a], tensor_dtype=np.int32)
diff --git a/tensorflow/python/kernel_tests/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py
index 5e0d87c..d267e49 100644
--- a/tensorflow/python/kernel_tests/ackermann_test.py
+++ b/tensorflow/python/kernel_tests/ackermann_test.py
@@ -34,7 +34,7 @@
     self.assertEqual(len(ackermann.OP_LIST.op), 1)
     self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
 
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
 
 
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 1202c46..127d14c 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -104,20 +104,20 @@
     self._testDim(np.int64)
 
   def testEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       for op in math_ops.argmin, math_ops.argmax:
         with self.assertRaisesOpError(
             r"Reduction axis 0 is empty in shape \[0\]"):
           op([], 0).eval()
 
   def testDefaultAxis(self):
-    with self.test_session():
+    with self.cached_session():
       for op in math_ops.argmin, math_ops.argmax:
         ans = op([1]).eval()
         self.assertAllEqual(ans, 0)
 
   def testOutputEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       for op in math_ops.argmin, math_ops.argmax:
         ret = op(array_ops.zeros(shape=[1, 0, 2]), axis=-1).eval()
         self.assertEqual(ret.shape, (1, 0))
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index a164682..573bb86 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -50,7 +50,7 @@
   def testNonBatchMatrix(self):
     matrix = [[1, 2, 3], [4, 5, 6]]  # Shape (2, 3)
     expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)
-    with self.test_session():
+    with self.cached_session():
       transposed = array_ops.matrix_transpose(matrix)
       self.assertEqual((3, 2), transposed.get_shape())
       self.assertAllEqual(expected_transposed, transposed.eval())
@@ -58,7 +58,7 @@
   def testConjugate(self):
     m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
     expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
-    with self.test_session():
+    with self.cached_session():
       matrix = ops.convert_to_tensor(m)
       transposed = array_ops.matrix_transpose(matrix, conjugate=True)
       self.assertEqual((3, 2), transposed.get_shape())
@@ -71,7 +71,7 @@
     matrix_1_t = [[11, 44], [22, 55], [33, 66]]
     batch_matrix = [matrix_0, matrix_1]  # Shape (2, 2, 3)
     expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)
-    with self.test_session():
+    with self.cached_session():
       transposed = array_ops.matrix_transpose(batch_matrix)
       self.assertEqual((2, 3, 2), transposed.get_shape())
       self.assertAllEqual(expected_transposed, transposed.eval())
@@ -79,7 +79,7 @@
   def testNonBatchMatrixDynamicallyDefined(self):
     matrix = [[1, 2, 3], [4, 5, 6]]  # Shape (2, 3)
     expected_transposed = [[1, 4], [2, 5], [3, 6]]  # Shape (3, 2)
-    with self.test_session():
+    with self.cached_session():
       matrix_ph = array_ops.placeholder(dtypes.int32)
       transposed = array_ops.matrix_transpose(matrix_ph)
       self.assertAllEqual(
@@ -94,7 +94,7 @@
     matrix_1_t = [[11, 44], [22, 55], [33, 66]]
     batch_matrix = [matrix_0, matrix_1]  # Shape (2, 2, 3)
     expected_transposed = [matrix_0_t, matrix_1_t]  # Shape (2, 3, 2)
-    with self.test_session():
+    with self.cached_session():
       batch_matrix_ph = array_ops.placeholder(dtypes.int32)
       transposed = array_ops.matrix_transpose(batch_matrix_ph)
       self.assertAllEqual(
@@ -105,7 +105,7 @@
 
   def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
     vector = [1, 2, 3]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "should be a "):
         array_ops.matrix_transpose(vector)
 
@@ -129,7 +129,7 @@
       masked_arr = arr[:, mask]
     elif axis == 2:
       masked_arr = arr[:, :, mask]
-    with self.test_session():
+    with self.cached_session():
       masked_tensor = array_ops.boolean_mask(arr, mask, axis=axis)
 
       # Leading dimension size of masked_tensor is always unknown until runtime
@@ -176,7 +176,7 @@
     numpy_result = arr[mask]
     tf_result = array_ops.boolean_mask(arr, mask)
     self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(numpy_result, tf_result.eval())
 
   def testEmptyInput1D(self):
@@ -185,7 +185,7 @@
     numpy_result = arr[mask]
     tf_result = array_ops.boolean_mask(arr, mask)
     self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(numpy_result, tf_result.eval())
 
   def testEmptyOutput(self):
@@ -199,7 +199,7 @@
   def testWorksWithDimensionsEqualToNoneDuringGraphBuild(self):
     # The rank of the mask tensor must be specified. This is explained
     # in the docstring as well.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ph_tensor = array_ops.placeholder(dtypes.int32, shape=None)
       ph_mask = array_ops.placeholder(dtypes.bool, shape=[None])
 
@@ -217,7 +217,7 @@
   def testMaskDimensionsSetToNoneRaises(self):
     # The rank of the mask tensor must be specified. This is explained
     # in the docstring as well.
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.int32, shape=[None, 2])
       mask = array_ops.placeholder(dtypes.bool, shape=None)
       with self.assertRaisesRegexp(ValueError, "dimensions must be specified"):
@@ -226,21 +226,21 @@
   def testMaskHasMoreDimsThanTensorRaises(self):
     mask = [[True, True], [False, False]]
     tensor = [1, 2, 3, 4]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "incompatible"):
         array_ops.boolean_mask(tensor, mask).eval()
 
   def testMaskIsScalarRaises(self):
     mask = True
     tensor = 1
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "mask.*scalar"):
         array_ops.boolean_mask(tensor, mask).eval()
 
   def testMaskShapeDifferentThanFirstPartOfTensorShapeRaises(self):
     mask = [True, True, True]
     tensor = [[1, 2], [3, 4]]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "incompatible"):
         array_ops.boolean_mask(tensor, mask).eval()
 
@@ -345,7 +345,7 @@
   def testInvalid(self):
     x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
     axis = array_ops.placeholder(dtypes.int32)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "is out of valid range"):
         array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]})
@@ -954,7 +954,7 @@
 class SliceAssignTest(test_util.TensorFlowTestCase):
 
   def testInvalidSlice(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       foo = constant_op.constant([1, 2, 3])
       with self.assertRaisesRegexp(ValueError, "Sliced assignment"
                                    " is only supported for variables"):
@@ -1000,7 +1000,7 @@
     with self.assertRaisesRegexp(
         errors.FailedPreconditionError,
         "Attempting to use uninitialized value Variable"):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         v = variables.Variable([1, 2])
         sess.run(v[:].assign([1, 2]))
 
@@ -1019,7 +1019,7 @@
     too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
     too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
     v = resource_variable_ops.ResourceVariable(init_val)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(v.initializer)
       with self.assertRaises(ValueError):
         sess.run(v[:].assign(too_large_val))
@@ -1066,12 +1066,12 @@
 class SequenceMaskTest(test_util.TensorFlowTestCase):
 
   def testExceptions(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"):
         array_ops.sequence_mask([10, 20], [10, 20])
 
   def testOneDimensionalWithMaxlen(self):
-    with self.test_session():
+    with self.cached_session():
       res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
       self.assertAllEqual(res.get_shape(), [3, 5])
       self.assertAllEqual(
@@ -1081,7 +1081,7 @@
 
   @test_util.enable_c_shapes
   def testOneDimensionalDtypeWithoutMaxlen(self):
-    with self.test_session():
+    with self.cached_session():
       # test dtype and default maxlen:
       res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]),
                                     dtype=dtypes.float32)
@@ -1092,7 +1092,7 @@
 
   @test_util.enable_c_shapes
   def testOneDimensionalWithoutMaxlen(self):
-    with self.test_session():
+    with self.cached_session():
       res = array_ops.sequence_mask(
           constant_op.constant([0, 1, 4]))
       self.assertAllEqual(res.get_shape().as_list(), [3, 4])
@@ -1104,7 +1104,7 @@
 
   @test_util.enable_c_shapes
   def testTwoDimensional(self):
-    with self.test_session():
+    with self.cached_session():
       res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
       self.assertAllEqual(res.get_shape(), [1, 3, 5])
       self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [
@@ -1137,7 +1137,7 @@
           [[True, False, False, False, False], [True, True, True, False, False],
            [True, True, False, False, False]])
 
-    with self.test_session():
+    with self.cached_session():
       check_dtypes(dtypes.int32, dtypes.int32)
       check_dtypes(dtypes.int32, dtypes.int64)
       check_dtypes(dtypes.int64, dtypes.int32)
@@ -1216,7 +1216,7 @@
   # TODO(b/73086570): Reenable test.
   @unittest.skip("Test does not pass internally.")
   def testUnravelIndex(self):
-    with self.test_session():
+    with self.cached_session():
       for dtype in [dtypes.int32, dtypes.int64]:
         indices_1 = constant_op.constant(1621, dtype=dtype)
         dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
@@ -1237,13 +1237,13 @@
 class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
 
   def testSimple(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.constant(10)
       guarantee_a = array_ops.guarantee_const(a)
       self.assertEqual(10, guarantee_a.eval())
 
   def testVariables(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for use_resource in [False, True]:
         a = variable_scope.get_variable(
             "var_{}".format(use_resource), [],
@@ -1254,7 +1254,7 @@
         self.assertEqual(10.0, guarantee_a.eval())
 
   def testResourceRejection(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = variable_scope.get_variable(
           "resource_var", [],
           initializer=init_ops.constant_initializer(10.0),
diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py
index 51aa17b..dd4a90e 100644
--- a/tensorflow/python/kernel_tests/as_string_op_test.py
+++ b/tensorflow/python/kernel_tests/as_string_op_test.py
@@ -32,7 +32,7 @@
         0, 1, -1, 0.5, 0.25, 0.125, float("INF"), float("NAN"), float("-INF")
     ]
 
-    with self.test_session():
+    with self.cached_session():
       for dtype in (dtypes.float32, dtypes.float64):
         input_ = array_ops.placeholder(dtype)
 
@@ -84,7 +84,7 @@
     int_inputs_ = [0, -1, 1, -128, 127, -101, 101, -0]
     s = lambda strs: [x.decode("ascii") for x in strs]
 
-    with self.test_session():
+    with self.cached_session():
       for dtype in (dtypes.int32, dtypes.int64, dtypes.int8):
         input_ = array_ops.placeholder(dtype)
 
@@ -117,7 +117,7 @@
     # testing int8
     s = lambda strs: [x.decode("ascii") for x in strs]
 
-    with self.test_session():
+    with self.cached_session():
       input_ = array_ops.placeholder(dtypes.int32)
       int_inputs_ = [np.iinfo(np.int32).min, np.iinfo(np.int32).max]
       output = string_ops.as_string(input_)
@@ -133,7 +133,7 @@
   def testHalfInt(self):
     s = lambda strs: [x.decode("ascii") for x in strs]
 
-    with self.test_session():
+    with self.cached_session():
       input_ = array_ops.placeholder(dtypes.int16)
       int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max]
       output = string_ops.as_string(input_)
@@ -144,7 +144,7 @@
     bool_inputs_ = [False, True]
     s = lambda strs: [x.decode("ascii") for x in strs]
 
-    with self.test_session():
+    with self.cached_session():
       for dtype in (dtypes.bool,):
         input_ = array_ops.placeholder(dtype)
 
@@ -159,7 +159,7 @@
     ]
     complex_inputs_ = [(x + (x + 1) * 1j) for x in float_inputs_]
 
-    with self.test_session():
+    with self.cached_session():
       for dtype in (dtypes.complex64, dtypes.complex128):
         input_ = array_ops.placeholder(dtype)
 
diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py
index b98e5fd..6b16fca 100644
--- a/tensorflow/python/kernel_tests/atrous_convolution_test.py
+++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py
@@ -263,7 +263,7 @@
     self.assertLess(err, err_tolerance)
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       for padding in ["SAME", "VALID"]:
         for rate_width in range(1, 3):
           for rate_height in range(1, 3):
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
index fb74698..1e09ba5 100644
--- a/tensorflow/python/kernel_tests/attention_ops_test.py
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -84,7 +84,7 @@
         image_ops.extract_glimpse(t_cols_4d, t1, t2), [0, 2, 1, 3]))
 
     # Evaluate the TensorFlow Graph.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols])
 
     # Check dimensions of returned glimpse.
@@ -118,7 +118,7 @@
   def testEmptyTensor(self):
     empty_image = np.zeros((0, 4, 3, 0))
     offsets = np.zeros((0, 2))
-    with self.test_session():
+    with self.cached_session():
       result = image_ops.extract_glimpse(empty_image, [1, 1], offsets)
       self.assertAllEqual(
           np.zeros(
diff --git a/tensorflow/python/kernel_tests/barrier_ops_test.py b/tensorflow/python/kernel_tests/barrier_ops_test.py
index 7f49c63..4d36b3a 100644
--- a/tensorflow/python/kernel_tests/barrier_ops_test.py
+++ b/tensorflow/python/kernel_tests/barrier_ops_test.py
@@ -67,7 +67,7 @@
       """, b.barrier_ref.op.node_def)
 
   def testInsertMany(self):
-    with self.test_session():
+    with self.cached_session():
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       size_t = b.ready_size()
@@ -83,7 +83,7 @@
       self.assertEquals(size_t.eval(), [3])
 
   def testInsertManyEmptyTensor(self):
-    with self.test_session():
+    with self.cached_session():
       error_message = ("Empty tensors are not supported, but received shape "
                        r"\'\(0,\)\' at index 1")
       with self.assertRaisesRegexp(ValueError, error_message):
@@ -91,7 +91,7 @@
             (dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B")
 
   def testInsertManyEmptyTensorUnknown(self):
-    with self.test_session():
+    with self.cached_session():
       b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B")
       size_t = b.ready_size()
       self.assertEqual([], size_t.get_shape())
@@ -103,7 +103,7 @@
         insert_0_op.run()
 
   def testTakeMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       size_t = b.ready_size()
@@ -128,7 +128,7 @@
       self.assertEqual(values_1_val[idx], v1)
 
   def testTakeManySmallBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       size_t = b.ready_size()
@@ -192,7 +192,7 @@
         insert_1_3_op.run()
 
   def testUseBarrierWithShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B")
       size_t = b.ready_size()
@@ -221,7 +221,7 @@
       self.assertAllEqual(values_1_val[idx], v1)
 
   def testParallelInsertMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(dtypes.float32, shapes=())
       size_t = b.ready_size()
       keys = [str(x).encode("ascii") for x in range(10)]
@@ -241,7 +241,7 @@
       self.assertEqual(values_val[idx], v)
 
   def testParallelTakeMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(dtypes.float32, shapes=())
       size_t = b.ready_size()
       keys = [str(x).encode("ascii") for x in range(10)]
@@ -275,7 +275,7 @@
         zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)])
 
   def testBlockingTakeMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(dtypes.float32, shapes=())
       keys = [str(x).encode("ascii") for x in range(10)]
       values = [float(x) for x in range(10)]
@@ -297,7 +297,7 @@
       t.join()
 
   def testParallelInsertManyTakeMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.int64), shapes=((), (2,)))
       num_iterations = 100
@@ -376,7 +376,7 @@
         self.assertAllEqual(taken_i["values_1"], expected_values_1)
 
   def testClose(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       size_t = b.ready_size()
@@ -434,7 +434,7 @@
         sess.run(take_t[0])
 
   def testCancel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       size_t = b.ready_size()
@@ -487,7 +487,7 @@
         sess.run(take_t[0])
 
   def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
       take_t = b.take_many(1, allow_small_batch=True)
@@ -500,7 +500,7 @@
     self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True)
 
   def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.int64), shapes=((), (2,)))
       num_iterations = 50
@@ -576,7 +576,7 @@
     self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True)
 
   def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = data_flow_ops.Barrier(
           (dtypes.float32, dtypes.int64), shapes=((), (2,)))
       num_iterations = 100
@@ -676,7 +676,7 @@
     self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True)
 
   def testIncompatibleSharedBarrierErrors(self):
-    with self.test_session():
+    with self.cached_session():
       # Do component types and shapes.
       b_a_1 = data_flow_ops.Barrier(
           (dtypes.float32,), shapes=(()), shared_name="b_a")
diff --git a/tensorflow/python/kernel_tests/base64_ops_test.py b/tensorflow/python/kernel_tests/base64_ops_test.py
index be96f45..1b39994 100644
--- a/tensorflow/python/kernel_tests/base64_ops_test.py
+++ b/tensorflow/python/kernel_tests/base64_ops_test.py
@@ -48,7 +48,7 @@
     return base64_msg
 
   def _RunTest(self, msg, pad):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if pad:
         encoded, decoded = sess.run([self._encoded_t, self._decoded_t],
                                     feed_dict={self._msg: msg})
@@ -92,7 +92,7 @@
         encoded = string_ops.encode_base64(msg, pad=pad)
         decoded = string_ops.decode_base64(encoded)
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           encoded_value, decoded_value = sess.run([encoded, decoded])
 
         self.assertEqual(encoded_value.shape, msg.shape)
@@ -102,7 +102,7 @@
     def try_decode(enc):
       self._decoded_f.eval(feed_dict={self._encoded_f: enc})
 
-    with self.test_session():
+    with self.cached_session():
       # Invalid length.
       msg = np.random.bytes(99)
       enc = base64.urlsafe_b64encode(msg)
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index 987a6ff..e651fa0 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -174,7 +174,7 @@
                         numeric_gradient_type=None):
     z = np_func(x, y)
     zs = list(z.shape)
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       if x.dtype in (np.float32, np.float64):
@@ -195,7 +195,7 @@
                         numeric_gradient_type=None):
     z = np_func(x, y)
     zs = list(z.shape)
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       if x.dtype in (np.float32, np.float64):
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 8e7ae89..7dd3479 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -86,7 +86,7 @@
 
   def testString(self):
     params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
-    with self.test_session():
+    with self.cached_session():
       indices_tf = constant_op.constant([1])
       self.assertAllEqual([[b"qwer", b"uiop"]],
                           array_ops.batch_gather(params, indices_tf).eval())
diff --git a/tensorflow/python/kernel_tests/batchtospace_op_test.py b/tensorflow/python/kernel_tests/batchtospace_op_test.py
index 6143cd3..03f3f64 100644
--- a/tensorflow/python/kernel_tests/batchtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/batchtospace_op_test.py
@@ -60,7 +60,7 @@
           array_ops.depth_to_space(
               array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
           [3, 1, 2, 0])
-      with self.test_session():
+      with self.cached_session():
         self.assertAllEqual(y1.eval(), y2.eval())
 
 
@@ -235,7 +235,7 @@
   # Check the gradients.
   def _checkGrad(self, x, crops, block_size):
     assert 4 == x.ndim
-    with self.test_session():
+    with self.cached_session():
       tf_x = ops.convert_to_tensor(x)
       tf_y = self.batch_to_space(tf_x, crops, block_size)
       epsilon = 1e-5
@@ -293,7 +293,7 @@
     block_shape = np.array(block_shape)
     crops = constant_op.constant(
         np.array(crops).reshape((len(block_shape), 2)), crops_dtype)
-    with self.test_session():
+    with self.cached_session():
       tf_x = ops.convert_to_tensor(x)
       tf_y = array_ops.batch_to_space_nd(tf_x, block_shape, crops)
       epsilon = 1e-5
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
index 3305e55..3ec820a 100644
--- a/tensorflow/python/kernel_tests/bcast_ops_test.py
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -28,11 +28,11 @@
 class BcastOpsTest(test.TestCase):
 
   def _GetBroadcastShape(self, xs, ys):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       return sess.run(broadcast_args(xs, ys))
 
   def _GetGradientArgs(self, xs, ys):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       return sess.run(broadcast_gradient_args(xs, ys))
 
   def testBasic(self):
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index 16fdeda..92d2146 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -47,7 +47,7 @@
       tf_b_s = constant_op.constant(b_s, dtype=dtype)
       tf_x_s = constant_op.constant(x_s, dtype=dtype)
       tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
-      with self.test_session():
+      with self.cached_session():
         tf_out = tf_out_t.eval()
       scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
 
@@ -60,13 +60,13 @@
       # Test out-of-range values (most should return nan output)
       combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
       a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
-      with self.test_session():
+      with self.cached_session():
         tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
       scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
       self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
 
       # Test broadcasting between scalars and other shapes
-      with self.test_session():
+      with self.cached_session():
         self.assertAllCloseAccordingToType(
             special.betainc(0.1, b_s, x_s).astype(np_dt),
             math_ops.betainc(0.1, b_s, x_s).eval(),
@@ -96,7 +96,7 @@
       with self.assertRaisesRegexp(ValueError, "must be equal"):
         math_ops.betainc(0.5, [0.5], [[0.5]])
 
-      with self.test_session():
+      with self.cached_session():
         with self.assertRaisesOpError("Shapes of .* are inconsistent"):
           a_p = array_ops.placeholder(dtype)
           b_p = array_ops.placeholder(dtype)
@@ -140,7 +140,7 @@
     self._testBetaInc(a_s, b_s, x_s, dtypes.float32)
 
   def testBetaIncFpropAndBpropAreNeverNAN(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       space = np.logspace(-8, 5).tolist()
       space_x = np.linspace(1e-16, 1 - 1e-16).tolist()
       ga_s, gb_s, gx_s = zip(*list(itertools.product(space, space, space_x)))
@@ -161,7 +161,7 @@
 
   def testBetaIncGrads(self):
     err_tolerance = 1e-3
-    with self.test_session():
+    with self.cached_session():
       # Test gradient
       ga_s = np.abs(np.random.randn(2, 2) * 30)  # in (0, infty)
       gb_s = np.abs(np.random.randn(2, 2) * 30)  # in (0, infty)
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 2767df1..8a58b3f 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -93,7 +93,7 @@
 
   def test_negative(self):
     # unsorted_segment_sum will only report InvalidArgumentError on CPU
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors.InvalidArgumentError):
         math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
 
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
index 4f92ab0..2044678 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/BUILD
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -74,3 +74,16 @@
         "//tensorflow/python:resources",
     ],
 )
+
+tf_py_test(
+    name = "quantile_ops_test",
+    size = "small",
+    srcs = ["quantile_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+        "//tensorflow/python:boosted_trees_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:resources",
+    ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 4e31b1e..dee9610 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -30,7 +30,7 @@
 
   def testCachedPredictionOnEmptyEnsemble(self):
     """Tests that prediction on a dummy ensemble does not fail."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create a dummy ensemble.
       tree_ensemble = boosted_trees_ops.TreeEnsemble(
           'ensemble', serialized_proto='')
@@ -63,7 +63,7 @@
 
   def testNoCachedPredictionButTreeExists(self):
     """Tests that predictions are updated once trees are added."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -129,7 +129,7 @@
 
   def testCachedPredictionIsCurrent(self):
     """Tests that prediction based on previous node in the tree works."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -201,7 +201,7 @@
 
   def testCachedPredictionFromTheSameTree(self):
     """Tests that prediction based on previous node in the tree works."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -315,7 +315,7 @@
 
   def testCachedPredictionFromPreviousTree(self):
     """Tests the predictions work when we have cache from previous trees."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -447,7 +447,7 @@
 
   def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
     """Tests that prediction based on previous node in the tree works."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -577,7 +577,7 @@
 
   def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
     """Tests that prediction based on previous node in the tree works."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -722,7 +722,7 @@
 
   def testCachedPredictionTheWholeTreeWasPruned(self):
     """Tests that prediction based on previous node in the tree works."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -794,7 +794,7 @@
 
   def testPredictionOnEmptyEnsemble(self):
     """Tests that prediction on a empty ensemble does not fail."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create an empty ensemble.
       tree_ensemble = boosted_trees_ops.TreeEnsemble(
           'ensemble', serialized_proto='')
@@ -816,7 +816,7 @@
 
   def testPredictionMultipleTree(self):
     """Tests the predictions work when we have multiple trees."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -930,7 +930,7 @@
 
   def testContribsMultipleTree(self):
     """Tests that the contribs work when we have multiple trees."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
new file mode 100644
index 0000000..c71b8df
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for checking quantile related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
+from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
+from tensorflow.python.platform import googletest
+
+
+class QuantileOpsTest(test_util.TensorFlowTestCase):
+
+  def create_resource(self, name, eps, max_elements, num_streams=1):
+    quantile_accumulator_handle = resource_handle_op(
+        container="", shared_name=name, name=name)
+    create_op = boosted_trees_ops.create_quantile_stream_resource(
+        quantile_accumulator_handle,
+        epsilon=eps,
+        max_elements=max_elements,
+        num_streams=num_streams)
+    is_initialized_op = resource_initialized(quantile_accumulator_handle)
+    resources.register_resource(quantile_accumulator_handle, create_op,
+                                is_initialized_op)
+    return quantile_accumulator_handle
+
+  def setUp(self):
+    """Sets up the quantile ops test as follows.
+
+    Create a batch of 6 examples having 2 features
+    The data looks like this
+    | Instance | instance weights | Feature 0 | Feature 1
+    | 0        |     10           |   1.2     |   2.3
+    | 1        |     1            |   12.1    |   1.2
+    | 2        |     1            |   0.3     |   1.1
+    | 3        |     1            |   0.5     |   2.6
+    | 4        |     1            |   0.6     |   3.2
+    | 5        |     1            |   2.2     |   0.8
+    """
+
+    self._feature_0 = constant_op.constant(
+        [[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
+    self._feature_1 = constant_op.constant(
+        [[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
+    self._feature_0_boundaries = constant_op.constant(
+        [0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
+    self._feature_1_boundaries = constant_op.constant(
+        [0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
+    self._feature_0_quantiles = constant_op.constant(
+        [[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
+    self._feature_1_quantiles = constant_op.constant(
+        [[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
+
+    self._example_weights = constant_op.constant(
+        [10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
+
+    self.eps = 0.01
+    self.max_elements = 1 << 16
+    self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+
+  def testBasicQuantileBucketsSingleResource(self):
+    with self.cached_session() as sess:
+      quantile_accumulator_handle = self.create_resource("floats", self.eps,
+                                                         self.max_elements, 2)
+      resources.initialize_resources(resources.shared_resources()).run()
+      summaries = boosted_trees_ops.make_quantile_summaries(
+          [self._feature_0, self._feature_1], self._example_weights,
+          epsilon=self.eps)
+      summary_op = boosted_trees_ops.quantile_add_summaries(
+          quantile_accumulator_handle, summaries)
+      flush_op = boosted_trees_ops.quantile_flush(
+          quantile_accumulator_handle, self.num_quantiles)
+      buckets = boosted_trees_ops.get_bucket_boundaries(
+          quantile_accumulator_handle, num_features=2)
+      quantiles = boosted_trees_ops.boosted_trees_bucketize(
+          [self._feature_0, self._feature_1], buckets)
+      sess.run(summary_op)
+      sess.run(flush_op)
+      self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
+      self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
+
+      self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+      self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+  def testBasicQuantileBucketsMultipleResources(self):
+    with self.cached_session() as sess:
+      quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
+                                                           self.max_elements)
+      quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
+                                                           self.max_elements)
+      resources.initialize_resources(resources.shared_resources()).run()
+      summaries = boosted_trees_ops.make_quantile_summaries(
+          [self._feature_0, self._feature_1], self._example_weights,
+          epsilon=self.eps)
+      summary_op_0 = boosted_trees_ops.quantile_add_summaries(
+          quantile_accumulator_handle_0,
+          [summaries[0]])
+      summary_op_1 = boosted_trees_ops.quantile_add_summaries(
+          quantile_accumulator_handle_1,
+          [summaries[1]])
+      flush_op_0 = boosted_trees_ops.quantile_flush(
+          quantile_accumulator_handle_0, self.num_quantiles)
+      flush_op_1 = boosted_trees_ops.quantile_flush(
+          quantile_accumulator_handle_1, self.num_quantiles)
+      bucket_0 = boosted_trees_ops.get_bucket_boundaries(
+          quantile_accumulator_handle_0, num_features=1)
+      bucket_1 = boosted_trees_ops.get_bucket_boundaries(
+          quantile_accumulator_handle_1, num_features=1)
+      quantiles = boosted_trees_ops.boosted_trees_bucketize(
+          [self._feature_0, self._feature_1], bucket_0 + bucket_1)
+      sess.run([summary_op_0, summary_op_1])
+      sess.run([flush_op_0, flush_op_1])
+      self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
+      self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
+
+      self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+      self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
index d5f0c22..65bb9ab 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
@@ -31,7 +31,7 @@
   """Tests resource_ops."""
 
   def testCreate(self):
-    with self.test_session():
+    with self.cached_session():
       ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
       resources.initialize_resources(resources.shared_resources()).run()
       stamp_token = ensemble.get_stamp_token()
@@ -44,7 +44,7 @@
       self.assertAllEqual([0, 1], nodes_range.eval())
 
   def testCreateWithProto(self):
-    with self.test_session():
+    with self.cached_session():
       ensemble_proto = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
@@ -161,7 +161,7 @@
       self.assertAllEqual([16, 19], nodes_range.eval())
 
   def testSerializeDeserialize(self):
-    with self.test_session():
+    with self.cached_session():
       # Initialize.
       ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
       resources.initialize_resources(resources.shared_resources()).run()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 568e695..09e9cfa 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -30,7 +30,7 @@
 
   def testCalculateBestGainsWithoutRegularization(self):
     """Testing Gain calculation without any regularization."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -78,7 +78,7 @@
 
   def testCalculateBestGainsWithL2(self):
     """Testing Gain calculation with L2."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -126,7 +126,7 @@
 
   def testCalculateBestGainsWithL1(self):
     """Testing Gain calculation with L1."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -177,7 +177,7 @@
 
   def testCalculateBestGainsWithTreeComplexity(self):
     """Testing Gain calculation with L2."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -229,7 +229,7 @@
 
   def testCalculateBestGainsWithMinNodeWeight(self):
     """Testing Gain calculation without any regularization."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -276,7 +276,7 @@
 
   def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
     """Testing Gain calculation without any regularization."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       max_splits = 7
       node_id_range = [1, 3]  # node 1 through 2 will be processed.
       stats_summary_list = [
@@ -329,7 +329,7 @@
 
   def testMakeStatsSummarySimple(self):
     """Simple test for MakeStatsSummary."""
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose([[[[1., 5.], [2., 6.]], [[3., 7.], [4., 8.]]]],
                           boosted_trees_ops.make_stats_summary(
                               node_ids=[0, 0, 1, 1],
@@ -341,7 +341,7 @@
 
   def testMakeStatsSummaryAccumulate(self):
     """Tests that Summary actually accumulates."""
-    with self.test_session():
+    with self.cached_session():
       max_splits = 3
       num_buckets = 4
       node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -363,7 +363,7 @@
 
   def testMakeStatsSummaryMultipleFeatures(self):
     """Tests that MakeStatsSummary works for multiple features."""
-    with self.test_session():
+    with self.cached_session():
       max_splits = 3
       num_buckets = 4
       node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -392,7 +392,7 @@
           result.eval())
 
   def _verify_precision(self, length):
-    with self.test_session():
+    with self.cached_session():
       max_splits = 1
       num_buckets = 1
       node_ids = array_ops.fill([length], 0)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index d552402..ea02282 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -32,7 +32,7 @@
 
   def testGrowWithEmptyEnsemble(self):
     """Test growing an empty ensemble."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
       tree_ensemble_handle = tree_ensemble.resource_handle
@@ -141,7 +141,7 @@
 
   def testBiasCenteringOnEmptyEnsemble(self):
     """Test growing with bias centering on an empty ensemble."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
       tree_ensemble_handle = tree_ensemble.resource_handle
@@ -184,7 +184,7 @@
 
   def testGrowExistingEnsembleTreeNotFinalized(self):
     """Test growing an existing ensemble with the last tree not finalized."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -368,7 +368,7 @@
 
   def testGrowExistingEnsembleTreeFinalized(self):
     """Test growing an existing ensemble with the last tree finalized."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -517,7 +517,7 @@
 
   def testPrePruning(self):
     """Test growing an existing ensemble with pre-pruning."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge("""
         trees {
@@ -673,7 +673,7 @@
 
   def testMetadataWhenCantSplitDueToEmptySplits(self):
     """Test that the metadata is updated even though we can't split."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
@@ -784,7 +784,7 @@
 
   def testMetadataWhenCantSplitDuePrePruning(self):
     """Test metadata is updated correctly when no split due to prepruning."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
@@ -919,7 +919,7 @@
 
   def testPostPruningOfSomeNodes(self):
     """Test growing an ensemble with post-pruning."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       tree_ensemble = boosted_trees_ops.TreeEnsemble(
@@ -1253,7 +1253,7 @@
 
   def testPostPruningOfAllNodes(self):
     """Test growing an ensemble with post-pruning, with all nodes are pruned."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       # Create empty ensemble.
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
@@ -1436,7 +1436,7 @@
 
   def testPostPruningChangesNothing(self):
     """Test growing an ensemble with post-pruning with all gains >0."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       # Create empty ensemble.
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index 6a1bd95..bd2339f 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -21,8 +21,10 @@
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
 from tensorflow.python.platform import test as test_lib
 
 
@@ -81,5 +83,47 @@
         # check shape inference when shape input is constant
         self.assertAllEqual(shape, v_np.shape)
 
+  def testGradientForScalar(self):
+    # TODO(alextp): There is a bug with broadcast_to on GPU from scalars,
+    # hence we make this test cpu-only.
+    with ops.device("cpu:0"):
+      x = constant_op.constant(1, dtype=dtypes.float32)
+      v = array_ops.broadcast_to(x, [2, 4, 3])
+      out = 2 * v
+      with self.cached_session():
+        err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                      out, out.get_shape())
+    self.assertLess(err, 1e-4)
+
+  def testGradientWithSameRank(self):
+    x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)),
+                             dtype=dtypes.float32)
+    v = array_ops.broadcast_to(x, [2, 5, 3])
+    out = 2 * v
+    with self.cached_session():
+      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                    out, out.get_shape())
+    self.assertLess(err, 1e-4)
+
+  def testGradientWithIncreasingRank(self):
+    x = constant_op.constant([[1], [2]],
+                             dtype=dtypes.float32)
+    v = array_ops.broadcast_to(x, [5, 2, 3])
+    out = 2 * v
+    with self.cached_session():
+      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                    out, out.get_shape())
+    self.assertLess(err, 1e-4)
+
+  def testGradientWithBroadcastAllDimensions(self):
+    x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
+    v = array_ops.broadcast_to(x, [5, 4, 6])
+    out = 2 * v
+    with self.cached_session():
+      err = gradient_checker.compute_gradient_error(x, x.get_shape(),
+                                                    out, out.get_shape())
+    self.assertLess(err, 1e-4)
+
+
 if __name__ == "__main__":
   test_lib.main()
diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
index 28b3dc4..b19077d 100644
--- a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
+++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
@@ -38,7 +38,7 @@
   TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
 
   def testTrueCandidates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       indices = constant_op.constant([0, 0, 1, 1, 2, 2])
       true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3])
       true_candidates_matrix = array_ops.reshape(
@@ -50,7 +50,7 @@
     self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)
 
   def testSampledCandidates(self):
-    with self.test_session():
+    with self.cached_session():
       true_classes = constant_op.constant(
           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
       sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -62,7 +62,7 @@
     self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])
 
   def testTrueLogExpectedCount(self):
-    with self.test_session():
+    with self.cached_session():
       true_classes = constant_op.constant(
           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
       _, true_expected_count, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -77,7 +77,7 @@
                      [self.BATCH_SIZE, self.NUM_TRUE])
 
   def testSampledLogExpectedCount(self):
-    with self.test_session():
+    with self.cached_session():
       true_classes = constant_op.constant(
           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
       _, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler(  # pylint: disable=line-too-long
@@ -90,7 +90,7 @@
     self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED])
 
   def testAccidentalHits(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       true_classes = constant_op.constant(
           [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
       sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -109,7 +109,7 @@
   def testSeed(self):
 
     def draw(seed):
-      with self.test_session():
+      with self.cached_session():
         true_classes = constant_op.constant(
             [[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
         sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler(
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 214d5cb..c90520e 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -174,7 +174,7 @@
     self.assertAllEqual(np.isnan(self._cast(np.nan, np.float64, True)), True)
 
   def _OpError(self, x, dtype, err):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError(err):
         math_ops.cast(x, dtype).eval()
 
@@ -182,7 +182,7 @@
     self._OpError(np.arange(0, 10), dtypes.string, "Cast.*int64.*string.*")
 
   def testCastToTypeOfVariable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = variables.Variable(5, dtype=dtypes.float32)
       y = variables.Variable(True, dtype=dtypes.bool)
       cast = math_ops.cast(y, x.dtype)
@@ -193,7 +193,7 @@
     t = [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
     for src_t in t:
       for dst_t in t:
-        with self.test_session():
+        with self.cached_session():
           x = constant_op.constant(1.0, src_t)
           z = array_ops.identity(x)
           y = math_ops.cast(z, dst_t)
@@ -209,7 +209,7 @@
     shape = constant_op.constant([3], dtypes.int64)
     st = sparse_tensor.SparseTensor(indices, values, shape)
     st_cast = math_ops.cast(st, dtypes.float32)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(st_cast.indices.eval(), [[0], [1], [2]])
       self.assertAllEqual(st_cast.values.eval(),
                           np.array([1, 2, 3], np.float32))
@@ -221,7 +221,7 @@
   def testSaturate(self):
     in_types = dtypes.float32,
     out_types = dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.float32
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for in_type in in_types:
         for out_type in out_types:
           lo, hi = in_type.min, in_type.max
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 05f998d..27a674e 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -116,7 +117,7 @@
       check_ops.assert_equal(static_big, static_small, message="fail")
 
   def test_raises_when_greater_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       small = array_ops.placeholder(dtypes.int32, name="small")
       big = array_ops.placeholder(dtypes.int32, name="big")
       with ops.control_dependencies(
@@ -194,7 +195,7 @@
       check_ops.assert_equal(static_big, static_small, message="fail")
 
   def test_raises_when_less_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       small = array_ops.placeholder(dtypes.int32, name="small")
       big = array_ops.placeholder(dtypes.int32, name="big")
       with ops.control_dependencies([check_ops.assert_equal(small, big)]):
@@ -271,30 +272,28 @@
 
   @test_util.run_in_graph_and_eager_modes
   def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1, 1, 1], name="small")
-      big = constant_op.constant([10, 10], name="big")
-      # The exception in eager and non-eager mode is different because
-      # eager mode relies on shape check done as part of the C++ op, while
-      # graph mode does shape checks when creating the `Operation` instance.
-      with self.assertRaisesRegexp(
-          (ValueError, errors.InvalidArgumentError),
-          (r"Incompatible shapes: \[3\] vs. \[2\]|"
-           r"Dimensions must be equal, but are 3 and 2")):
-        with ops.control_dependencies(
-            [check_ops.assert_none_equal(small, big)]):
-          out = array_ops.identity(small)
-        self.evaluate(out)
+    small = constant_op.constant([1, 1, 1], name="small")
+    big = constant_op.constant([10, 10], name="big")
+    # The exception in eager and non-eager mode is different because
+    # eager mode relies on shape check done as part of the C++ op, while
+    # graph mode does shape checks when creating the `Operation` instance.
+    with self.assertRaisesRegexp(
+        (ValueError, errors.InvalidArgumentError),
+        (r"Incompatible shapes: \[3\] vs. \[2\]|"
+         r"Dimensions must be equal, but are 3 and 2")):
+      with ops.control_dependencies(
+          [check_ops.assert_none_equal(small, big)]):
+        out = array_ops.identity(small)
+      self.evaluate(out)
 
   @test_util.run_in_graph_and_eager_modes
   def test_doesnt_raise_when_both_empty(self):
-    with self.test_session():
-      larry = constant_op.constant([])
-      curly = constant_op.constant([])
-      with ops.control_dependencies(
-          [check_ops.assert_none_equal(larry, curly)]):
-        out = array_ops.identity(larry)
-      self.evaluate(out)
+    larry = constant_op.constant([])
+    curly = constant_op.constant([])
+    with ops.control_dependencies(
+        [check_ops.assert_none_equal(larry, curly)]):
+      out = array_ops.identity(larry)
+    self.evaluate(out)
 
   def test_returns_none_with_eager(self):
     with context.eager_mode():
@@ -821,6 +820,18 @@
     with self.test_session() as sess:
       sess.run(derived, feed_dict={placeholder: feed_val})
 
+  def testGradient(self):
+    placeholder = array_ops.placeholder(dtypes.float32)
+    derived = check_ops.ensure_shape(placeholder, (None, None))
+    gradient = gradients.gradients(derived, placeholder)
+
+    feed_val = [[4.0], [-1.0]]
+    with self.cached_session() as sess:
+      gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
+
+    expected = [[1.0], [1.0]]
+    self.assertAllEqual(gradient_values, expected)
+
 
 class EnsureShapeBenchmark(test.Benchmark):
 
@@ -905,7 +916,7 @@
         self.evaluate(array_ops.identity(tensor))
 
   def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 1
       with ops.control_dependencies(
@@ -923,7 +934,7 @@
       self.evaluate(array_ops.identity(tensor))
 
   def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 0
       with ops.control_dependencies(
@@ -940,7 +951,7 @@
         self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 0
       with ops.control_dependencies(
@@ -957,7 +968,7 @@
       self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 1
       with ops.control_dependencies(
@@ -974,7 +985,7 @@
         self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 2
       with ops.control_dependencies(
@@ -989,7 +1000,7 @@
       check_ops.assert_rank(tensor, np.array([], dtype=np.int32))
 
   def test_raises_if_rank_is_not_scalar_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant(
           [1, 2], dtype=dtypes.float32, name="my_tensor")
       rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor")
@@ -1006,7 +1017,7 @@
       check_ops.assert_rank(tensor, .5)
 
   def test_raises_if_rank_is_not_integer_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant(
           [1, 2], dtype=dtypes.float32, name="my_tensor")
       rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -1029,7 +1040,7 @@
         self.evaluate(array_ops.identity(tensor_rank0))
 
   def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
       with ops.control_dependencies([
           check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
@@ -1045,7 +1056,7 @@
         self.evaluate(array_ops.identity(tensor_rank0))
 
   def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
       for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
         with ops.control_dependencies([
@@ -1061,7 +1072,7 @@
         self.evaluate(array_ops.identity(tensor_rank1))
 
   def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
       for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
         with ops.control_dependencies([
@@ -1079,7 +1090,7 @@
         self.evaluate(array_ops.identity(tensor_rank1))
 
   def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
       with ops.control_dependencies([
           check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
@@ -1098,7 +1109,7 @@
       check_ops.assert_rank_in(tensor, desired_ranks)
 
   def test_raises_if_rank_is_not_scalar_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant(
           (42, 43), dtype=dtypes.float32, name="my_tensor")
       desired_ranks = (
@@ -1120,7 +1131,7 @@
       check_ops.assert_rank_in(tensor, (1, .5,))
 
   def test_raises_if_rank_is_not_integer_dynamic(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant(
           (42, 43), dtype=dtypes.float32, name="my_tensor")
       rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
@@ -1143,7 +1154,7 @@
         self.evaluate(array_ops.identity(tensor))
 
   def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 1
       with ops.control_dependencies(
@@ -1160,7 +1171,7 @@
       self.evaluate(array_ops.identity(tensor))
 
   def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 0
       with ops.control_dependencies(
@@ -1176,7 +1187,7 @@
       self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 0
       with ops.control_dependencies(
@@ -1192,7 +1203,7 @@
       self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 1
       with ops.control_dependencies(
@@ -1209,7 +1220,7 @@
         self.evaluate(array_ops.identity(tensor))
 
   def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
       desired_rank = 2
       with ops.control_dependencies(
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
index 7f147ba..51611b7 100644
--- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py
+++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
@@ -57,7 +57,7 @@
         new_vocab_offset=0)
     expected_remapping = range(0, 3)
     expected_num_present = 3
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_remapping, remapping.eval())
       self.assertAllEqual(expected_num_present, num_present.eval())
 
@@ -70,7 +70,7 @@
         new_vocab_offset=0)
     expected_remapping = [2, 0, 1]
     expected_num_present = 3
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_remapping, remapping.eval())
       self.assertAllEqual(expected_num_present, num_present.eval())
 
@@ -83,7 +83,7 @@
         new_vocab_offset=1)
     expected_remapping = [0]
     expected_num_present = 1
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_remapping, remapping.eval())
       self.assertAllEqual(expected_num_present, num_present.eval())
 
@@ -98,7 +98,7 @@
         old_vocab_size=2)
     expected_remapping = [-1, 0, 1]
     expected_num_present = 2
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_remapping, remapping.eval())
       self.assertAllEqual(expected_num_present, num_present.eval())
 
@@ -122,7 +122,7 @@
       self.old_tensor_name = 'some_scope/matrix'
 
     save = saver.Saver([matrix])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables.global_variables_initializer().run()
       self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
       save.save(sess, self.bundle_file)
@@ -140,7 +140,7 @@
         initializing_values=[],
         num_rows=2,
         num_cols=self.old_num_cols)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(self.matrix_value[row_remapping],
                           remapped_matrix.eval())
 
@@ -155,7 +155,7 @@
         initializing_values=[],
         num_rows=len(row_remapping),
         num_cols=len(col_remapping))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                           remapped_matrix.eval())
 
@@ -170,7 +170,7 @@
         initializing_values=[],
         num_rows=len(row_remapping),
         num_cols=len(col_remapping))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
                           remapped_matrix.eval())
 
@@ -189,7 +189,7 @@
     expected_remapped_matrix = np.reshape(
         [33, init_val, init_val, init_val, 1, init_val], [3, 2])
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
 
   def test_load_and_remap_all_missing_rows(self):
@@ -204,7 +204,7 @@
         initializing_values=initializing_values,
         num_rows=num_rows,
         num_cols=self.old_num_cols)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(
           np.reshape(initializing_values, (num_rows, self.old_num_cols)),
           remapped_matrix.eval())
@@ -222,7 +222,7 @@
         initializing_values=initializing_values,
         num_rows=num_rows,
         num_cols=num_cols)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(
           np.reshape(initializing_values, (num_rows, num_cols)),
           remapped_matrix.eval())
@@ -243,7 +243,7 @@
         initializing_values=[],
         num_rows=len(invalid_remapping),
         num_cols=self.old_num_cols)
-    with self.test_session(), self.assertRaises(errors.UnimplementedError):
+    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
       remapped_matrix.eval()
 
     # Invalid column remapping.
@@ -255,7 +255,7 @@
         initializing_values=[],
         num_rows=self.old_num_rows,
         num_cols=len(invalid_remapping))
-    with self.test_session(), self.assertRaises(errors.UnimplementedError):
+    with self.cached_session(), self.assertRaises(errors.UnimplementedError):
       remapped_matrix.eval()
 
   def test_load_and_remap_incorrect_initializing_values(self):
@@ -272,7 +272,7 @@
         initializing_values=[],
         num_rows=3,
         num_cols=2)
-    with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
       remapped_matrix.eval()
 
     remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
@@ -284,7 +284,7 @@
         initializing_values=[0] * 5,
         num_rows=3,
         num_cols=2)
-    with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+    with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
       remapped_matrix.eval()
 
 
@@ -306,7 +306,7 @@
         initializer=constant_op.constant(np_value, dtype=dtypes.float32),
         partitioner=partitioner)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
       save = saver.Saver([matrix])
       variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index 400d38b..bb7b645 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -27,6 +27,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.platform import test
 
 
@@ -38,7 +39,7 @@
     min_val = constant_op.constant([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float32)
     max_val = constant_op.constant([3.5, 3.5, 3.5, 3.5], dtype=dtypes.float32)
     outputs_2 = clip_ops.clip_by_value(inputs, min_val, max_val)
-    with self.test_session():
+    with self.cached_session():
       error_1 = gradient_checker.compute_gradient_error(inputs, [4], outputs_1,
                                                         [4])
       self.assertLess(error_1, 1e-4)
@@ -138,7 +139,7 @@
 
   def testClipByValueNonFinite(self):
     # TODO(b/78016351): Enable test on GPU once the bug is fixed.
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
       np_ans = [float('NaN'), 4.0, -4.0]
       clip_value = 4.0
@@ -158,13 +159,19 @@
       ans = clip_ops.clip_by_norm(x, clip_norm)
       tf_ans = ans.eval()
 
-      clip_tensor = constant_op.constant(4.0)
       ans = clip_ops.clip_by_norm(x, clip_norm)
       tf_ans_tensor = ans.eval()
 
     self.assertAllClose(np_ans, tf_ans)
     self.assertAllClose(np_ans, tf_ans_tensor)
 
+  def testClipByNormGradientZeros(self):
+    with self.cached_session(use_gpu=True):
+      x = array_ops.zeros([3])
+      b = clip_ops.clip_by_norm(x, 1.)
+      grad, = gradients_impl.gradients(b, x)
+      self.assertAllEqual(grad.eval(), [1., 1., 1.])
+
   def testClipByNormBadShape(self):
     with self.test_session(use_gpu=True):
       x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index c22934c..0e59ce6 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -383,7 +383,7 @@
         np.random.random_sample(x_shape).astype(np.float64)
         for x_shape in x_shapes
     ]
-    with self.test_session():
+    with self.cached_session():
       xs = [constant_op.constant(x_val) for x_val in x_vals]
       output = array_ops.concat(xs, 0)
       err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -397,7 +397,7 @@
         np.random.random_sample(x_shape).astype(np.float64)
         for x_shape in x_shapes
     ]
-    with self.test_session():
+    with self.cached_session():
       xs = [constant_op.constant(x_val) for x_val in x_vals]
       output = array_ops.concat(xs, 1)
       err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -411,7 +411,7 @@
         np.random.random_sample(x_shape).astype(np.float64)
         for x_shape in x_shapes
     ]
-    with self.test_session():
+    with self.cached_session():
       xs = [constant_op.constant(x_val) for x_val in x_vals]
       x_concat = array_ops.concat(xs, 0)
       output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -426,7 +426,7 @@
         np.random.random_sample(x_shape).astype(np.float64)
         for x_shape in x_shapes
     ]
-    with self.test_session():
+    with self.cached_session():
       xs = [constant_op.constant(x_val) for x_val in x_vals]
       x_concat = array_ops.concat(xs, 1)
       output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -441,7 +441,7 @@
         np.random.random_sample(x_shape).astype(np.float64)
         for x_shape in x_shapes
     ]
-    with self.test_session():
+    with self.cached_session():
       xs = [constant_op.constant(x_val) for x_val in x_vals]
       x_concat = array_ops.concat(xs, 2)
       output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -452,7 +452,7 @@
   def testIndexedSlicesConcatDim1Grad_UnknownInputDim(self):
     x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
     output_shape = [4, 11, 3]
-    with self.test_session():
+    with self.cached_session():
       x_1 = array_ops.placeholder(dtypes.float64)
       x_2 = array_ops.placeholder(dtypes.float64)
       x_3 = array_ops.placeholder(dtypes.float64)
@@ -473,13 +473,13 @@
   def testConcatTuple(self):
     c1 = np.random.rand(4, 4)
     c2 = np.random.rand(4, 4)
-    with self.test_session():
+    with self.cached_session():
       concat_list_t = array_ops.concat([c1, c2], 0)
       concat_tuple_t = array_ops.concat((c1, c2), 0)
       self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
 
   def testConcatNoScalars(self):
-    with self.test_session():
+    with self.cached_session():
       scalar = constant_op.constant(7)
       dim = array_ops.placeholder(dtypes.int32)
       with self.assertRaisesRegexp(
@@ -554,7 +554,7 @@
 
   def _testGradientsForAxis(
       self, inp_tensors, axis, output_shape, feed_dict=None):
-    with self.test_session():
+    with self.cached_session():
       c = array_ops.concat(inp_tensors, axis)
       grad_inp = np.random.rand(*output_shape).astype("f")
       grad_tensor = constant_op.constant(
@@ -566,7 +566,7 @@
 
   def _testIndexedSlicesGradientsForAxis(
       self, inp_tensors, axis, output_shape, gather_indexes, feed_dict=None):
-    with self.test_session():
+    with self.cached_session():
       c = array_ops.gather(
           array_ops.concat(inp_tensors, axis), gather_indexes)
       grad_inp = np.random.rand(*output_shape).astype("f")
@@ -631,7 +631,7 @@
       self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
 
   def testNotVector(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       cdim = constant_op.constant(1, dtypes.int32)
       s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
       s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
@@ -641,7 +641,7 @@
         sess.run(off)
 
   def testConcatDimOutOfRange(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       cdim = constant_op.constant(4, dtypes.int32)
       s0 = constant_op.constant([2, 3, 5], dtypes.int32)
       s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -651,7 +651,7 @@
         sess.run(off)
 
   def testDimMismatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       cdim = constant_op.constant(1, dtypes.int32)
       s0 = constant_op.constant([2, 3, 5], dtypes.int32)
       s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
@@ -661,7 +661,7 @@
         sess.run(off)
 
   def testSizeMismatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       cdim = constant_op.constant(1, dtypes.int32)
       s0 = constant_op.constant([2, 3, 5], dtypes.int32)
       s1 = constant_op.constant([2, 7, 10], dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 0dc3c53..a1efecf 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -107,7 +107,7 @@
     self._testCond(true_fn, false_fn, [y])
 
   def testNoInputs(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       pred = array_ops.placeholder(dtypes.bool, name="pred")
 
       def true_fn():
@@ -527,7 +527,7 @@
             }), [5., 0.])
 
   def testSecondDerivative(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       pred = array_ops.placeholder(dtypes.bool, name="pred")
       x = constant_op.constant(3.0, name="x")
 
@@ -801,7 +801,6 @@
 class CondV2ColocationGroupAndDeviceTest(test.TestCase):
 
   def testColocateWithBeforeCond(self):
-    self.skipTest("b/112414483")
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g):
 
@@ -826,7 +825,6 @@
             self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
 
   def testColocateWithInAndOutOfCond(self):
-    self.skipTest("b/112414483")
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g):
 
@@ -874,7 +872,6 @@
         self.assertTrue(len(run_metadata.partition_graphs) >= 2)
 
   def testDeviceBeforeCond(self):
-    self.skipTest("b/112166045")
     with ops.Graph().as_default() as g:
       with self.test_session(graph=g):
         def fn():
@@ -895,11 +892,13 @@
 
   def testDeviceInAndOutOfCond(self):
     with ops.Graph().as_default() as g:
-      with self.test_session(graph=g):
+      with self.test_session(
+          graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):
+
         def fn2():
-          with ops.device("/device:GPU:0"):
+          with ops.device("/device:CPU:1"):
             c = constant_op.constant(3.0)
-            self.assertEqual("/device:GPU:0", c.op.device)
+            self.assertEqual("/device:CPU:1", c.op.device)
             return c
 
         with ops.device("/device:CPU:0"):
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 7570523..262352a 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -42,14 +42,22 @@
     with ops.Graph().as_default():
       q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
     self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       name:'Q' op:'ConditionalAccumulator'
       attr { key: 'dtype' value { type: DT_FLOAT } }
       attr { key: 'shape' value { shape { unknown_rank: true} } }
       attr { key: 'container' value { s: '' } }
       attr { key: 'shared_name' value { s: '' } }
+      attr { key: 'reduction_type' value {s: 'MEAN'} }
       """, q.accumulator_ref.op.node_def)
 
+  def testConstructorWithInvalidArg(self):
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):
+        data_flow_ops.ConditionalAccumulator(
+            dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
   def testConstructorWithShape(self):
     with ops.Graph().as_default():
       q = data_flow_ops.ConditionalAccumulator(
@@ -57,7 +65,8 @@
           name="Q",
           shape=tensor_shape.TensorShape([1, 5, 2, 8]))
     self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       name:'Q' op:'ConditionalAccumulator'
       attr { key: 'dtype' value { type: DT_FLOAT } }
       attr { key: 'shape' value { shape { dim {size: 1 }
@@ -67,29 +76,30 @@
       } } }
       attr { key: 'container' value { s: '' } }
       attr { key: 'shared_name' value { s: '' } }
+      attr { key: 'reduction_type' value {s: 'MEAN'} }
       """, q.accumulator_ref.op.node_def)
 
   def testAccumulatorSizeEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
       self.assertEqual(q.num_accumulated().eval(), 0)
 
   def testAccumulatorSetGlobalStep(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       set_global_step_op = q.set_global_step(1)
       set_global_step_op.run()
 
   def testAccumulatorApplyGradFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       accum_op = q.apply_grad((10.0,))
       accum_op.run()
 
   def testDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
 
       for i in range(len(dtypes)):
@@ -106,7 +116,7 @@
         self.assertEqual(sum(elems) / len(elems), result)
 
   def testAccumulatorMultipleAccumulators(self):
-    with self.test_session():
+    with self.cached_session():
       q_f32_0 = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       q_f32_1 = data_flow_ops.ConditionalAccumulator(
@@ -125,7 +135,7 @@
         self.assertEqual(result, i + 10.0)
 
   def testAccumulatorApplyAndTakeGradWithShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=(3, 2))
       elems = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
@@ -156,7 +166,7 @@
       q.apply_grad([[1.0], [2.0], [3.0]])
 
   def testAccumulatorDynamicShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=None)
 
@@ -181,7 +191,7 @@
       self.assertTrue(is_all_equal)
 
   def testAccumulatorWrongDynamicShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=None)
 
@@ -199,7 +209,7 @@
         sess.run(accum_op, feed_dict={x: [[1.0], [2.0], [3.0]]})
 
   def testAccumulatorSizeAfterApplyGrad(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       accum_op = q.apply_grad((10.0,))
@@ -210,7 +220,7 @@
       self.assertEqual(q.num_accumulated().eval(), 2)
 
   def testAccumulatorSizeAfterApplyGradAndTakeGrad(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       accum_op = q.apply_grad((10.0,))
@@ -237,12 +247,11 @@
       extract_t.op.run()
       self.assertEqual(q.num_accumulated().eval(), 0)
 
-  def testAccumulatorTakeGrad(self):
-    with self.test_session():
+  def testAccumulatorTakeGradMean(self):
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       elems = [10.0, 20.0]
-      elems_ave = sum(elems) / len(elems)
 
       accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
       takeg_t = q.take_grad(1)
@@ -251,7 +260,7 @@
         accum_op.run()
 
       val = takeg_t.eval()
-      self.assertEqual(elems_ave, val)
+      self.assertEqual(15.0, val)
 
       accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
       takeg_t = q.take_grad(constant_op.constant(1))
@@ -260,10 +269,45 @@
         accum_op.run()
 
       val = takeg_t.eval()
-      self.assertEqual(elems_ave, val)
+      self.assertEqual(15.0, val)
+
+  def testAccumulatorTakeGradSum(self):
+    with self.cached_session():
+      q = data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="SUM")
+      elems = [10.0, 20.0]
+
+      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(30.0, val)
+
+      accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+      takeg_t = q.take_grad(constant_op.constant(1))
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(30.0, val)
+
+  def testAccumulatorTakeGradInvalidReductionType(self):
+    with self.assertRaises(ValueError):
+      data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="Invalid")
 
   def testAccumulatorInvalidTakeGrad(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       elems = [10.0, 20.0]
@@ -277,8 +321,8 @@
       with self.assertRaises(errors_impl.InvalidArgumentError):
         takeg_t.eval()
 
-  def testAccumulatorRepeatedTakeGrad(self):
-    with self.test_session():
+  def testAccumulatorRepeatedTakeGradMean(self):
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
 
@@ -304,9 +348,39 @@
       val = takeg_t.eval()
       self.assertEqual(elems_ave + 0.0, val)
 
-  def testAccumulatorIncrementGlobalStep(self):
+  def testAccumulatorRepeatedTakeGradSum(self):
     with self.test_session():
       q = data_flow_ops.ConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([1]),
+          reduction_type="SUM")
+
+      elems = [10.0, 20.0]
+      elems_sum = 30.0
+      accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(elems_sum, val)
+
+      elems = [20.0, 30.0]
+      elems_sum = 50.0
+      accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+      takeg_t = q.take_grad(1)
+
+      for accum_op in accum_ops:
+        accum_op.run()
+
+      val = takeg_t.eval()
+      self.assertEqual(elems_sum, val)
+
+  def testAccumulatorIncrementGlobalStep(self):
+    with self.cached_session():
+      q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
 
       global_step = variables.Variable(0, name="global_step")
@@ -321,7 +395,7 @@
         inc_global_step.eval()
 
   def testAccumulatorSetGlobalStepPreventsAccumulation(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
 
@@ -342,7 +416,7 @@
                                                      if x >= ls), val)
 
   def testParallelApplyGrad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -367,7 +441,7 @@
       self.assertEqual(val, sum(elems) / len(elems))
 
   def testParallelTakeGrad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       elems = [e for e in range(10)]
@@ -399,7 +473,7 @@
       self.assertItemsEqual(elems, results)
 
   def testAccumulatorApplyAndBlockingTake(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
 
@@ -432,7 +506,7 @@
       sess.run(takeg_op)
 
   def testAccumulatorCancel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.ConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       takeg_t = q.take_grad(1)
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 93f5323..bc24345 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -37,7 +37,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testExample(self):
     """This is a test of the example provided in pydoc."""
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual([
           [0, 0, 0, 0, 0],
           [0, 0, 1, 0, 0],
@@ -49,7 +49,7 @@
 
   def _testConfMatrix(self, labels, predictions, truth, weights=None,
                       num_classes=None):
-    with self.test_session():
+    with self.cached_session():
       dtype = predictions.dtype
       ans = confusion_matrix.confusion_matrix(
           labels, predictions, dtype=dtype, weights=weights,
@@ -78,7 +78,7 @@
     self._testBasic(dtype=np.int64)
 
   def _testConfMatrixOnTensors(self, tf_dtype, np_dtype):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       m_neg = array_ops.placeholder(dtype=dtypes.float32)
       m_pos = array_ops.placeholder(dtype=dtypes.float32)
       s = array_ops.placeholder(dtype=dtypes.float32)
@@ -229,7 +229,7 @@
   def testOutputIsInt32(self):
     labels = np.arange(2)
     predictions = np.arange(2)
-    with self.test_session():
+    with self.cached_session():
       cm = confusion_matrix.confusion_matrix(
           labels, predictions, dtype=dtypes.int32)
       tf_cm = cm.eval()
@@ -238,7 +238,7 @@
   def testOutputIsInt64(self):
     labels = np.arange(2)
     predictions = np.arange(2)
-    with self.test_session():
+    with self.cached_session():
       cm = confusion_matrix.confusion_matrix(
           labels, predictions, dtype=dtypes.int64)
       tf_cm = cm.eval()
@@ -260,7 +260,7 @@
         confusion_matrix.remove_squeezable_dimensions(
             labels_placeholder, predictions_placeholder))
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(label_values, static_labels.eval())
       self.assertAllEqual(prediction_values, static_predictions.eval())
       feed_dict = {
@@ -285,7 +285,7 @@
         confusion_matrix.remove_squeezable_dimensions(
             labels_placeholder, predictions_placeholder))
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(label_values, static_labels.eval())
       self.assertAllEqual(prediction_values, static_predictions.eval())
       feed_dict = {
@@ -310,7 +310,7 @@
         confusion_matrix.remove_squeezable_dimensions(
             labels_placeholder, predictions_placeholder, expected_rank_diff=0))
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(label_values, static_labels.eval())
       self.assertAllEqual(prediction_values, static_predictions.eval())
       feed_dict = {
@@ -336,7 +336,7 @@
             labels_placeholder, predictions_placeholder))
 
     expected_label_values = np.reshape(label_values, newshape=(2, 3))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_label_values, static_labels.eval())
       self.assertAllEqual(prediction_values, static_predictions.eval())
       feed_dict = {
@@ -362,7 +362,7 @@
             labels_placeholder, predictions_placeholder, expected_rank_diff=1))
 
     expected_label_values = np.reshape(label_values, newshape=(2, 3))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_label_values, static_labels.eval())
       self.assertAllEqual(prediction_values, static_predictions.eval())
       feed_dict = {
@@ -388,7 +388,7 @@
             labels_placeholder, predictions_placeholder))
 
     expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(label_values, static_labels.eval())
       self.assertAllEqual(expected_prediction_values, static_predictions.eval())
       feed_dict = {
@@ -415,7 +415,7 @@
             labels_placeholder, predictions_placeholder, expected_rank_diff=-1))
 
     expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(label_values, static_labels.eval())
       self.assertAllEqual(expected_prediction_values, static_predictions.eval())
       feed_dict = {
@@ -441,7 +441,7 @@
         confusion_matrix.remove_squeezable_dimensions(
             labels_placeholder, predictions_placeholder))
 
-    with self.test_session():
+    with self.cached_session():
       feed_dict = {
           labels_placeholder: label_values,
           predictions_placeholder: prediction_values
@@ -466,7 +466,7 @@
         confusion_matrix.remove_squeezable_dimensions(
             labels_placeholder, predictions_placeholder))
 
-    with self.test_session():
+    with self.cached_session():
       feed_dict = {
           labels_placeholder: label_values,
           predictions_placeholder: prediction_values
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 107ee37..d1e4e54 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -162,18 +162,18 @@
       logging_const_op.run()
 
   def testStringWithNulls(self):
-    with self.test_session():
+    with self.cached_session():
       val = ops.convert_to_tensor(b"\0\0\0\0").eval()
     self.assertEqual(len(val), 4)
     self.assertEqual(val, b"\0\0\0\0")
 
-    with self.test_session():
+    with self.cached_session():
       val = ops.convert_to_tensor(b"xx\0xx").eval()
     self.assertEqual(len(val), 5)
     self.assertAllEqual(val, b"xx\0xx")
     nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]]
 
-    with self.test_session():
+    with self.cached_session():
       val = ops.convert_to_tensor(nested).eval()
     # NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
     #   numpy array, which loses the null terminators.
@@ -279,7 +279,7 @@
     self.assertTrue(isinstance(x, ops.Tensor))
 
   def testAsTensorForShapeInput(self):
-    with self.test_session():
+    with self.cached_session():
       x = ops.convert_to_tensor(tensor_shape.TensorShape([]))
       self.assertEqual(dtypes_lib.int32, x.dtype)
       self.assertAllEqual([], x.eval())
@@ -331,7 +331,7 @@
           tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.float32)
 
   def testAsTensorForDimensionInput(self):
-    with self.test_session():
+    with self.cached_session():
       x = ops.convert_to_tensor(tensor_shape.TensorShape([1, 2, 3])[1])
       self.assertEqual(dtypes_lib.int32, x.dtype)
       self.assertAllEqual(2, x.eval())
@@ -367,7 +367,7 @@
 class ZerosTest(test.TestCase):
 
   def _Zeros(self, shape):
-    with self.test_session():
+    with self.cached_session():
       ret = array_ops.zeros(shape)
       self.assertEqual(shape, ret.get_shape())
       return ret.eval()
@@ -379,13 +379,13 @@
   def testScalar(self):
     self.assertEqual(0, self._Zeros([]))
     self.assertEqual(0, self._Zeros(()))
-    with self.test_session():
+    with self.cached_session():
       scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32))
       self.assertEqual(0, scalar.eval())
 
   def testDynamicSizes(self):
     np_ans = np.array([[0] * 3] * 2)
-    with self.test_session():
+    with self.cached_session():
       # Creates a tensor of 2 x 3.
       d = array_ops.fill([2, 3], 12., name="fill")
       # Constructs a tensor of zeros of the same dimensions as "d".
@@ -396,7 +396,7 @@
     self.assertShapeEqual(np_ans, z)
 
   def testDtype(self):
-    with self.test_session():
+    with self.cached_session():
       d = array_ops.fill([2, 3], 12., name="fill")
       self.assertEqual(d.get_shape(), [2, 3])
       # Test default type for both constant size and dynamic size
@@ -489,7 +489,7 @@
 
   def testZerosLikeDtype(self):
     # Make sure zeros_like works even for dtypes that cannot be cast between
-    with self.test_session():
+    with self.cached_session():
       shape = (3, 5)
       dtypes = np.float32, np.complex64
       for in_type in dtypes:
@@ -533,7 +533,7 @@
 class OnesTest(test.TestCase):
 
   def _Ones(self, shape):
-    with self.test_session():
+    with self.cached_session():
       ret = array_ops.ones(shape)
       self.assertEqual(shape, ret.get_shape())
       return ret.eval()
@@ -544,13 +544,13 @@
   def testScalar(self):
     self.assertEqual(1, self._Ones([]))
     self.assertEqual(1, self._Ones(()))
-    with self.test_session():
+    with self.cached_session():
       scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32))
       self.assertEqual(1, scalar.eval())
 
   def testDynamicSizes(self):
     np_ans = np.array([[1] * 3] * 2)
-    with self.test_session():
+    with self.cached_session():
       # Creates a tensor of 2 x 3.
       d = array_ops.fill([2, 3], 12., name="fill")
       # Constructs a tensor of ones of the same dimensions as "d".
@@ -561,7 +561,7 @@
     self.assertShapeEqual(np_ans, z)
 
   def testAutoPack(self):
-    with self.test_session():
+    with self.cached_session():
       h = array_ops.placeholder(dtypes_lib.int32, shape=[])
       w = array_ops.placeholder(dtypes_lib.int32, shape=[])
       z = array_ops.ones([h, w])
@@ -569,7 +569,7 @@
     self.assertAllEqual(out, np.array([[1] * 16] * 4))
 
   def testDtype(self):
-    with self.test_session():
+    with self.cached_session():
       d = array_ops.fill([2, 3], 12., name="fill")
       self.assertEqual(d.get_shape(), [2, 3])
       # Test default type for both constant size and dynamic size
@@ -606,7 +606,7 @@
         dtypes_lib.complex128
     ]:
       numpy_dtype = dtype.as_numpy_dtype
-      with self.test_session():
+      with self.cached_session():
         # Creates a tensor of non-zero values with shape 2 x 3.
         d = constant_op.constant(
             np.ones(
@@ -672,7 +672,7 @@
     self.assertAllEqual(np_ans, tf_ans)
 
   def testFillNegative(self):
-    with self.test_session():
+    with self.cached_session():
       for shape in (-1,), (2, -1), (-1, 2), (-2), (-3):
         with self.assertRaises(ValueError):
           array_ops.fill(shape, 7)
@@ -703,7 +703,7 @@
     self.assertEqual([None, 17], f.get_shape().as_list())
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       in_v = constant_op.constant(5.0)
       out_shape = [3, 2]
       out_filled = array_ops.fill(out_shape, in_v)
@@ -715,7 +715,7 @@
 class PlaceholderTest(test.TestCase):
 
   def testDtype(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
       p_identity = array_ops.identity(p)
       feed_array = np.random.rand(10, 10)
@@ -727,7 +727,7 @@
         p_identity.eval()
 
   def testShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
       p_identity = array_ops.identity(p)
       feed_array = np.random.rand(10, 10)
@@ -744,7 +744,7 @@
         p_identity.eval(feed_dict={p: feed_array[:5, :5]})
 
   def testUnknownShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=None, name="p")
       p_identity = array_ops.identity(p)
       # can feed anything
@@ -756,13 +756,13 @@
           p_identity.eval(feed_dict={p: feed_array}), feed_array)
 
   def testScalarShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=[], name="p")
       p_identity = array_ops.identity(p)
       self.assertAllClose(p_identity.eval(feed_dict={p: 5}), 5)
 
   def testPartialShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
       p_identity = array_ops.identity(p)
       feed_array = np.random.rand(10, 3)
@@ -774,7 +774,7 @@
         p_identity.eval(feed_dict={p: feed_array[:5, :2]})
 
   def testPartialShapeWhenNotFed(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
       p_identity = array_ops.identity(p)
 
@@ -784,7 +784,7 @@
         p_identity.eval()
 
   def testControlDependency(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes_lib.int32, shape=[], name="p")
       with ops.control_dependencies([p]):
         c = constant_op.constant(5, dtypes_lib.int32)
@@ -872,7 +872,7 @@
 """
     gdef = graph_pb2.GraphDef()
     text_format.Merge(graph, gdef)
-    with self.test_session():
+    with self.cached_session():
       p, ret = importer.import_graph_def(
           gdef, return_elements=["Placeholder:0", "add:0"])
 
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index eac97af..ebeabcf 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -129,7 +129,7 @@
 class ControlFlowTest(test.TestCase):
 
   def testRefIdentity(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(7)
 
       v = control_flow_ops._Identity(v)
@@ -141,7 +141,7 @@
       self.assertEqual(9, v2.eval())
 
   def testRefEnter(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(7)
 
       enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
@@ -154,7 +154,7 @@
       self.assertEqual(9, v3.eval())
 
   def testRefSwitch(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(7)
 
       p = constant_op.constant(True)
@@ -164,7 +164,7 @@
       self.assertEqual(9, v2.eval())
 
   def testEnterMulExit(self):
-    with self.test_session():
+    with self.cached_session():
       data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
       enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
       five = constant_op.constant(5)
@@ -176,7 +176,7 @@
     self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
 
   def testEnterShapePropagation(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
 
       # If is_constant=True, the shape information should be propagated.
@@ -190,7 +190,7 @@
       self.assertEqual(enter_v_non_constant.shape, None)
 
   def testSwitchMergeIndexedSlices(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([1, 2, 3, 4, 5, 6])
       indices = constant_op.constant([0, 2, 4, 6, 8, 10])
       data = ops.IndexedSlices(values, indices)
@@ -204,7 +204,7 @@
     self.assertAllEqual(np.arange(0, 12, 2), ind)
 
   def testSwitchDeadBranch(self):
-    with self.test_session():
+    with self.cached_session():
       data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
       ports = ops.convert_to_tensor(True, name="ports")
       switch_op = control_flow_ops.switch(data, ports)
@@ -216,7 +216,7 @@
         dead_branch.eval()
 
   def testSwitchMergeLess(self):
-    with self.test_session():
+    with self.cached_session():
       data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
       zero = ops.convert_to_tensor(0)
       one = ops.convert_to_tensor(1)
@@ -228,7 +228,7 @@
     self.assertAllEqual(np.arange(1, 7), result)
 
   def testSwitchMergeAddIdentity(self):
-    with self.test_session():
+    with self.cached_session():
       data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
       ports = ops.convert_to_tensor(False, name="ports")
       switch_op = control_flow_ops.switch(data, ports)
@@ -241,7 +241,7 @@
     self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
 
   def testSwitchMergeAddMul(self):
-    with self.test_session():
+    with self.cached_session():
       data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
       ports = ops.convert_to_tensor(True, name="ports")
       switch_op = control_flow_ops.switch(data, ports)
@@ -255,7 +255,7 @@
     self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
 
   def testLoop_false(self):
-    with self.test_session():
+    with self.cached_session():
       false = ops.convert_to_tensor(False)
       n = constant_op.constant(10)
 
@@ -272,7 +272,7 @@
     self.assertAllEqual(10, result)
 
   def testLoop_1(self):
-    with self.test_session():
+    with self.cached_session():
       zero = constant_op.constant(0)
       one = constant_op.constant(1)
       n = constant_op.constant(10)
@@ -298,7 +298,7 @@
     self.assertAllEqual(10, result)
 
   def testLoop_2(self):
-    with self.test_session():
+    with self.cached_session():
       zero = constant_op.constant(0)
       one = constant_op.constant(1)
       n = constant_op.constant(10)
@@ -324,7 +324,7 @@
     self.assertAllEqual(10, result)
 
   def testDifferentFrame(self):
-    with self.test_session():
+    with self.cached_session():
       data = array_ops.placeholder(dtypes.float32, shape=[])
       enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
       enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
@@ -333,7 +333,7 @@
         res.eval(feed_dict={data: 1.0})
 
   def testCondBool(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113296297")
 
     values = constant_op.constant(10)
@@ -352,7 +352,7 @@
     self.assertAllEqual([None], grad)
 
   def testFetchable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       control_flow_ops.cond(
           constant_op.constant(True), lambda: x + 2, lambda: x + 0)
@@ -367,7 +367,7 @@
               sess.run(t, feed_dict={x: 3})
 
   def testFeedable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(2)
       i0 = constant_op.constant(0)
       r = control_flow_ops.while_loop(lambda i: i < 1000,
@@ -384,10 +384,10 @@
               sess.run(r, feed_dict={t: 3})
 
   def testCondIndexedSlices(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113296180")
 
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant(10)
       indices = constant_op.constant(0)
       x = ops.IndexedSlices(values, indices)
@@ -402,10 +402,10 @@
     self.assertAllEqual(0, ind)
 
   def testCondSparseTensor(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113296161 (SparseTensors)")
 
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
       indices = constant_op.constant(
           [[0], [3]], dtype=dtypes.int64, name="indices")
@@ -422,10 +422,10 @@
       self.assertAllEqual(r.values.get_shape(), (2,))
 
   def testCondResource(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       rv = resource_variable_ops.ResourceVariable(True)
       variables.global_variables_initializer().run()
       t = ops.convert_to_tensor(1.0)
@@ -438,10 +438,10 @@
       self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
 
   def testCondIndexedSlicesDifferentTypes(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113293074")
 
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant(10)
       i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
       i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
@@ -484,17 +484,17 @@
     self.assertAllEqual(11, result)
 
   def testCond_1(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
     self._testCond_1(use_gpu=False)
     self._testCond_1(use_gpu=True)
 
   def testCond_2(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(10)
       r = control_flow_ops.cond(
           math_ops.less(1, 0), lambda: math_ops.add(x, 1),
@@ -503,10 +503,10 @@
     self.assertAllEqual(9, result)
 
   def testCond_3(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(10)
       pred = math_ops.less(1, 2)
       fn1 = lambda: math_ops.add(x, 1)
@@ -518,10 +518,10 @@
     self.assertAllEqual(12, result)
 
   def testCond_4(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113324949 (ref vars)")
 
-    with self.test_session():
+    with self.cached_session():
       v1 = variables.Variable(7)
       v2 = variables.Variable(7)
       v3 = variables.Variable(7)
@@ -542,7 +542,7 @@
       self.assertAllEqual(7, v3.eval())
 
   def testCond_5(self):
-    with self.test_session():
+    with self.cached_session():
       alive = constant_op.constant(True, name="alive")
       count = constant_op.constant(0, name="count")
 
@@ -556,10 +556,10 @@
       self.assertAllEqual(4, count.eval())
 
   def testCond_6(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       v1 = variables.Variable([7])
 
       age = constant_op.constant(3)
@@ -573,7 +573,7 @@
       self.assertAllEqual(np.array([7]), result)
 
   def testCond_7(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = constant_op.constant(10)
       y = constant_op.constant(200)
       pred = math_ops.less(1, 2)
@@ -583,10 +583,10 @@
       self.assertAllEqual([11, 12], sess.run(r))
 
   def testCondRef(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       x = gen_state_ops.variable(
           shape=[1],
           dtype=dtypes.float32,
@@ -599,10 +599,10 @@
       self.assertAllEqual([2.0], r.eval())
 
   def testCondWithControl(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/79881896")
 
-    with self.test_session() as sess:
+    with self.cached_session():
       control_holder = array_ops.placeholder(dtypes.float32, shape=())
       a = constant_op.constant(3)
 
@@ -617,7 +617,7 @@
       self.assertEqual(5, r.eval())
 
   def testUninitializedRefIdentity(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = gen_state_ops.variable(
           shape=[1],
           dtype=dtypes.float32,
@@ -641,7 +641,7 @@
       self.assertAllEqual([1.0], sess.run(merged_op.output))
 
   def testCondSwitchIdentity(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/112477618 (Operation returned from cond)")
 
     # Make sure the recv identity is not removed by optimization.
@@ -658,7 +658,7 @@
       sess.run(r)
 
   def testCondRecvIdentity(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/112477618 (Operation returned from cond)")
 
     # Make sure the switch identity is not removed by optimization.
@@ -677,7 +677,7 @@
       sess.run(r)
 
   def testCondGrad_1(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113346829 (gpu failure)")
 
     graph = ops.Graph()
@@ -689,11 +689,11 @@
       r = control_flow_ops.cond(pred, fn1, fn2)
 
       grad = gradients_impl.gradients(r, [x])[0]
-      with self.test_session():
+      with self.cached_session():
         self.assertAllEqual(1.0, grad.eval())
 
   def testCondGrad_2(self):
-    with self.test_session():
+    with self.cached_session():
       c = array_ops.placeholder(dtypes.int32, shape=[])
       x = constant_op.constant(10.0)
       pred = math_ops.less(c, 2)
@@ -706,10 +706,10 @@
       self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
 
   def testCondGrad_3(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/110550782 (gradient w.r.t external variable)")
 
-    with self.test_session():
+    with self.cached_session():
       c = array_ops.placeholder(dtypes.int32, shape=[])
       ox = constant_op.constant(10.0)
       pred = math_ops.less(c, 2)
@@ -726,7 +726,7 @@
       self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
 
   def testNestedCond_Simple(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(0., name="X")
       y = control_flow_ops.cond(
           constant_op.constant(True), lambda: x,
@@ -741,10 +741,10 @@
       self.assertEqual(1.0, result.eval())
 
   def testCondGrad_Gather(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113327884")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v1 = variables.Variable([1.0, 42.0])
       c = array_ops.placeholder(dtypes.int32, shape=[])
       pred = math_ops.less(c, 2)
@@ -768,7 +768,7 @@
 
   # Microbenchmark: 256,000 iterations/s.
   def testWhile_1(self):
-    with self.test_session():
+    with self.cached_session():
       n = constant_op.constant(0)
       c = lambda x: math_ops.less(x, 10000)
       b = lambda x: math_ops.add(x, 1)
@@ -776,7 +776,7 @@
       self.assertEqual(10000, r.eval())
 
   def testWhileExternalControlDependencies(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(0.0)
       v.initializer.run()
       increment = v.assign_add(1.0)
@@ -791,7 +791,7 @@
       self.assertAllEqual(v.eval(), 1.0)
 
   def testWhileExternalControlDependenciesNoInput(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(0.0)
       v.initializer.run()
       increment = v.assign_add(1.0)
@@ -806,7 +806,7 @@
       self.assertAllEqual(v.eval(), 1.0)
 
   def testWhileWithRefs_1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = variables.Variable(0)._ref()  # pylint: disable=protected-access
       i = constant_op.constant(0)
       c = lambda i, x: math_ops.less(i, 100)
@@ -830,19 +830,19 @@
     self.assertEqual(0, value_x)
 
   def testWhile_2(self):
-    with self.test_session():
+    with self.cached_session():
       s = constant_op.constant(0)
       r = isum(s)
       self.assertAllEqual(45, r.eval())
 
   def testWhileWithMaximumIterations(self):
-    with self.test_session():
+    with self.cached_session():
       s = constant_op.constant([1, 2, 3, 4, 5])
       r = isum(s, maximum_iterations=3)
       self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
 
   def testWhileWithMaximumIterationsAndSingleArgument(self):
-    with self.test_session():
+    with self.cached_session():
       r = control_flow_ops.while_loop(
           lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
       self.assertEqual(1, r.eval())
@@ -916,7 +916,7 @@
       _ = gradients_impl.gradients(loop_with_maxiter, v)
 
   def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294340 (enable while_v2)")
 
     v = constant_op.constant(1.0)
@@ -1019,7 +1019,7 @@
   # Have more than 10 parallel iterations and hence exercise k-bound
   # most of the time.
   def testWhile_3(self):
-    with self.test_session():
+    with self.cached_session():
 
       def compute(i, m, c, o):
         m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
@@ -1039,7 +1039,7 @@
     self.assertAllEqual(10100, result)
 
   def testWhile_4(self):
-    with self.test_session():
+    with self.cached_session():
 
       def compute(i, m, c, o):
         m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
@@ -1060,7 +1060,7 @@
     self.assertAllEqual(42, result)
 
   def testWhile_5(self):
-    with self.test_session():
+    with self.cached_session():
 
       def compute(i, c, o):
         c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
@@ -1088,7 +1088,7 @@
         trace_level=config_pb2.RunOptions.FULL_TRACE)
     run_metadata = config_pb2.RunMetadata()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with ops.device("/cpu:0"):
         c = constant_op.constant(2)
         i0 = constant_op.constant(0)
@@ -1134,7 +1134,7 @@
     self._testWhile_Gpu_1(use_gpu=True)
 
   def testWhileShape(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0)
       m = array_ops.ones([2, 2])
       c = lambda i, j: math_ops.less(i, 2)
@@ -1151,7 +1151,7 @@
       self.assertAllEqual(np.ones((8, 8)), r.eval())
 
   def testWhileWithNonTensorInput_Scalar(self):
-    with self.test_session():
+    with self.cached_session():
       n = 0
       c = lambda x: x < 10000
       b = lambda x: x + 1
@@ -1159,7 +1159,7 @@
       self.assertEqual(10000, r.eval())
 
   def testWhileWithNonTensorInput_Vector(self):
-    with self.test_session():
+    with self.cached_session():
       n = np.array([0])  # Note, [0] would not work here; that is a list
       c = lambda x: x[0] < 10000
       b = lambda x: array_ops.stack([x[0] + 1])
@@ -1167,7 +1167,7 @@
       self.assertEqual([10000], r.eval())
 
   def testWhileShapeInference(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0)
       m = array_ops.ones([2, 2])
       c = lambda i, j: math_ops.less(i, 2)
@@ -1192,7 +1192,7 @@
         r = control_flow_ops.while_loop(c, b, [i, m])
 
   def testWhileShapeInferenceSparseTensor(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
       indices = constant_op.constant(
           [[0], [3]], dtype=dtypes.int64, name="indices")
@@ -1223,7 +1223,7 @@
             [i.get_shape(), tensor_shape.TensorShape([5])])
 
   def testWhileShapeInferenceIndexedSlices(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
       indices = constant_op.constant([0, 3], name="indices")
       shape = constant_op.constant([10, 2], name="dense_shape")
@@ -1313,7 +1313,7 @@
     self._testNestedWhile_2(use_gpu=True)
 
   def testWhileWithControl_1(self):
-    with self.test_session():
+    with self.cached_session():
       n = constant_op.constant(0)
       r = constant_op.constant(0)
       condition = lambda n_, r_: math_ops.less(n_, 10)
@@ -1329,7 +1329,7 @@
       self.assertAllEqual(12, res[1].eval())
 
   def testWhileWithControl_2(self):
-    with self.test_session():
+    with self.cached_session():
       r = constant_op.constant(0)
       condition = lambda r_: math_ops.less(r_, 10)
 
@@ -1343,7 +1343,7 @@
       self.assertAllEqual(12, res.eval())
 
   def testWhileWithControl_3(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = array_ops.placeholder(dtypes.bool)
       c = constant_op.constant(1)
       x0 = constant_op.constant(0)
@@ -1352,7 +1352,7 @@
       self.assertEqual(10, sess.run(r, {b: True}))
 
   def testWhileWithControl_4(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = array_ops.placeholder(dtypes.bool)
       c = constant_op.constant(1)
       x0 = constant_op.constant(0)
@@ -1362,7 +1362,7 @@
       self.assertEqual(10, sess.run(r, {b: True}))
 
   def testWhileWithControl_5(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       b = array_ops.placeholder(dtypes.bool)
       c = constant_op.constant(1)
       x0 = constant_op.constant(0)
@@ -1375,12 +1375,12 @@
       self.assertEqual(10, sess.run(r, {b: True}))
 
   def testWhileCondWithControl(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
     # Ensure that no control edges by an outer control dependency context are
     # added to nodes inside cond/while contexts.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const_true = lambda: constant_op.constant(True)
       const_false = lambda: constant_op.constant(False)
       cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
@@ -1392,10 +1392,10 @@
       self.assertEqual(0, sess.run(loop))
 
   def testWhileCondWithControl_1(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113324949 (ref vars)")
 
-    with self.test_session():
+    with self.cached_session():
       v = variable_scope.get_variable(
           "v", [], initializer=init_ops.constant_initializer(2))
       i0 = constant_op.constant(0)
@@ -1417,10 +1417,10 @@
       self.assertAllClose(65536.0, v.eval())
 
   def testWhileCondExitControl(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294340 (enable while_v2)")
 
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(1)
 
       def false_branch():
@@ -1443,10 +1443,10 @@
       self.assertEqual(99, v.eval())
 
   def testCondWhile_1(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(0, name="n")
       c = lambda x: math_ops.less(x, 10)
       b = lambda x: math_ops.add(x, 1)
@@ -1456,10 +1456,10 @@
       self.assertAllEqual(10, r.eval())
 
   def testCondWhile_2(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(0)
       c = lambda x: math_ops.less(x, 10)
       b = lambda x: math_ops.add(x, 1)
@@ -1469,7 +1469,7 @@
       self.assertAllEqual(10, r.eval())
 
   def _testCondWhile_3(self, use_gpu):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294340 (enable while_v2)")
 
     with self.test_session(use_gpu=use_gpu) as sess:
@@ -1498,10 +1498,10 @@
     self._testCondWhile_3(use_gpu=True)
 
   def testWhileCond_1(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
-    with self.test_session():
+    with self.cached_session():
       i = ops.convert_to_tensor(0, name="i")
       n = ops.convert_to_tensor(10, name="n")
       one = ops.convert_to_tensor(1, name="one")
@@ -1516,10 +1516,10 @@
       self.assertAllEqual(10, r.eval())
 
   def testWhileCond_2(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(0, name="n")
       c = lambda x: math_ops.less(x, 10)
       b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
@@ -1527,10 +1527,10 @@
       self.assertAllEqual(10, r.eval())
 
   def testWhileCond_3(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(0)
       c = lambda x: math_ops.less(x, 10)
       # pylint: disable=undefined-variable
@@ -1544,7 +1544,7 @@
 
   # NOTE: It is ok to have parallel_iterations > 1
   def testWhileUpdateVariable_1(self):
-    with self.test_session():
+    with self.cached_session():
       select = variables.Variable([3.0, 4.0, 5.0])
       n = constant_op.constant(0)
 
@@ -1566,7 +1566,7 @@
       self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
 
   def testWhileUpdateVariable_2(self):
-    with self.test_session():
+    with self.cached_session():
       select1 = variables.Variable([3.0, 4.0, 5.0])
       select2 = variables.Variable([3.0, 4.0, 5.0])
       n = constant_op.constant(0)
@@ -1592,7 +1592,7 @@
       self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
 
   def testWhileUpdateVariable_3(self):
-    with self.test_session():
+    with self.cached_session():
       select = variables.Variable([3.0, 4.0, 5.0])
       n = constant_op.constant(0)
 
@@ -1614,7 +1614,7 @@
 
   # b/24814703
   def testWhileUpdateVariable_4(self):
-    with self.test_session():
+    with self.cached_session():
       var_a = variables.Variable(0, name="a")
       var_b = variables.Variable(0, name="b")
       variables.global_variables_initializer().run()
@@ -1642,7 +1642,7 @@
 
   # b/24736492
   def testWhileUpdateVariable_5(self):
-    with self.test_session():
+    with self.cached_session():
       # Create some variables.
       var_a = variables.Variable(0, name="a")
       var_b = variables.Variable(0, name="b")
@@ -1672,7 +1672,7 @@
 
   # b/24814668
   def testWhileUpdateVariable_6(self):
-    with self.test_session():
+    with self.cached_session():
       # Create some variables.
       var_a = variables.Variable(0, name="a")
       var_b = variables.Variable(0, name="b")
@@ -1701,7 +1701,7 @@
       self.assertEqual(10, var_a.eval())
 
   def testWhileQueue_1(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
       i = constant_op.constant(0)
 
@@ -1719,7 +1719,7 @@
         self.assertEqual([i], q.dequeue().eval())
 
   def testWhileStack_1(self):
-    with self.test_session():
+    with self.cached_session():
       s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
       i = constant_op.constant(0)
 
@@ -1753,7 +1753,7 @@
 
   def _testWhileGrad_ColocateGradients(self, colocate):
     gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
-    ) else "/device:GPU:0"
+    ) else "/device:CPU:0"
 
     graph = ops.Graph()
     with graph.as_default():
@@ -1791,7 +1791,7 @@
     self._testWhileGrad_ColocateGradients(colocate=True)
 
   def testWhileGrad_Square(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(2.0, name="v")
       c = lambda v: math_ops.less(v, 100.0)
       b = math_ops.square
@@ -1802,7 +1802,7 @@
       self.assertAllClose(1024.0, r.eval())
 
   def testWhileGrad_Shape(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtypes.float32, shape=[None])
       v = constant_op.constant([2.0], name="v")
       n = constant_op.constant(0, name="n")
@@ -1819,7 +1819,7 @@
       self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
 
   def testWhileGrad_BaseShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32, [None])
       v0 = constant_op.constant([2.0, 2.0], name="v")
       c = lambda v: constant_op.constant(False)
@@ -1831,7 +1831,7 @@
       self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
 
   def testWhileGrad_MultipleUses(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(2.0, name="v")
       c = lambda v: math_ops.less(v, 100.0)
       b = math_ops.square
@@ -1842,7 +1842,7 @@
       self.assertEqual(524288.0, r.eval())
 
   def testWhileGrad_LoopAdd(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(2.0, name="v")
       c = lambda v: math_ops.less(v, 100.0)
       b = math_ops.square
@@ -1872,7 +1872,7 @@
     self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
 
   def _testNestedWhileCondWhileGrad(self, use_gpu):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
     with self.test_session(use_gpu=use_gpu):
@@ -1901,7 +1901,7 @@
     self._testNestedWhileCondWhileGrad(use_gpu=True)
 
   def testWhileGrad_Variable(self):
-    with self.test_session():
+    with self.cached_session():
       a = variables.Variable(3.0)
       v = constant_op.constant(2.0, name="v")
       c = lambda v: math_ops.less(v, 100.0)
@@ -1913,10 +1913,10 @@
       self.assertAllClose(216.0, r[0].eval())
 
   def testWhileGradInCond(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/110550782 (gradient w.r.t external variable)")
 
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(1.0, name="n")
       x = array_ops.placeholder(dtypes.float32, shape=None)
       c = lambda n: math_ops.less(n, 10.0)
@@ -1931,7 +1931,7 @@
       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
 
   def testGradInWhileWrtInitialLoopVal(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
       y = x + 1
 
@@ -1948,7 +1948,7 @@
         control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
 
   def testWhileGradInWhile(self):
-    with self.test_session():
+    with self.cached_session():
       n = ops.convert_to_tensor(1.0, name="n")
       x = array_ops.placeholder(dtypes.float32, shape=None)
       c = lambda n: math_ops.less(n, 10.0)
@@ -1964,7 +1964,7 @@
       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
 
   def testCondGradInNestedWhiles(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113346829 (gpu failure)")
 
     def outer_body(i, x):
@@ -1978,13 +1978,13 @@
 
     i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       i_val, x_val = sess.run([i, x])
       self.assertEqual(i_val, 3)
       self.assertAllClose(x_val, 1.0)
 
   def testWhile_NestedInput(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       named = collections.namedtuple("named", ("a", "b"))
       loop_vars = [
           named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2011,7 +2011,7 @@
                        sess.run(r_flattened))
 
   def testWhile_NestedBadArityFails(self):
-    with self.test_session():
+    with self.cached_session():
       named = collections.namedtuple("named", ("a", "b"))
       loop_vars = [
           named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2027,7 +2027,7 @@
         control_flow_ops.while_loop(c, b, loop_vars)
 
   def testWhileGrad_ys_xs(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(3.0, name="x")
       y = constant_op.constant(2.0, name="y")
 
@@ -2050,7 +2050,7 @@
       self.assertAllClose(120.0, r[0].eval())
 
   def testWhileGrad_Dependency(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0, name="i")
       x = constant_op.constant(2.0, name="x")
 
@@ -2069,7 +2069,7 @@
       self.assertAllClose(1024.0, r[0].eval())
 
   def testWhileGrad_NoGradient(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(2.0, name="v")
       c = lambda v: math_ops.less(v, 100.0)
       b = math_ops.square
@@ -2079,7 +2079,7 @@
       self.assertAllClose(1.0, r[0].eval())
 
   def testWhileGrad_NoDependency(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variable = variables.Variable(array_ops.ones([2, 3]))
       duration = array_ops.zeros([], dtype=dtypes.int32)
 
@@ -2099,7 +2099,7 @@
       self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
 
   def testWhileGrad_Const(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c0 = constant_op.constant(0.0, name="c0")
       c1 = constant_op.constant(1.0, name="c1")
       duration = constant_op.constant(0, name="t")
@@ -2118,7 +2118,7 @@
       self.assertAllClose(0.0, sess.run(grad[0]))
 
   def testWhileGrad_SerialTwoLoops(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0, name="i")
       x = constant_op.constant(2.0, name="x")
 
@@ -2136,7 +2136,7 @@
       self.assertAllClose(1024.0, r[0].eval())
 
   def testWhileGrad_ParallelTwoLoops(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0, name="i")
       x = constant_op.constant(2.0, name="x")
 
@@ -2155,7 +2155,7 @@
       self.assertAllClose(64.0, r[0].eval())
 
   def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
-    with self.test_session():
+    with self.cached_session():
       i = constant_op.constant(0, name="i")
       x = constant_op.constant(1.0, name="x")
       y = constant_op.constant(1.0, name="y")
@@ -2196,7 +2196,7 @@
     self._testNestedWhileGrad_Simple(use_gpu=True)
 
   def testNestedWhileGrad_SerialInner(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(1.0)
 
       def inner_loop1(s):
@@ -2219,7 +2219,7 @@
       self.assertAllClose(256.0, r.eval())
 
   def testNestedWhileGrad_ParallelInner(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant(1.0)
 
       def inner_loop1(s):
@@ -2244,7 +2244,7 @@
   def testNestedWhileGrad_ParallelIterations(self):
     # Make sure the stack pushes and pops of an inner loop are executed in
     # the sequential order of the iterations of its outer loop.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       def inner_loop(t):
         fn = lambda n: n + math_ops.square(var)
@@ -2280,14 +2280,14 @@
       self.assertAllClose(1024.0, r.eval())
 
   def testWhileCondGrad_Simple(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113294377 (unknown shape)")
 
     self._testWhileCondGrad_Simple(use_gpu=False)
     self._testWhileCondGrad_Simple(use_gpu=True)
 
   def testWhileCondGrad_UnknownShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = array_ops.placeholder(dtypes.float32)
       n = ops.convert_to_tensor(100.0, name="n")
       one = ops.convert_to_tensor(1.0, name="one")
@@ -2304,7 +2304,7 @@
       self.assertAllClose(1024.0, r)
 
   def testWhileGrad_Concat(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = variable_scope.get_variable("x", initializer=[[1., 2.]])
       i0 = constant_op.constant(0)
       h0 = array_ops.zeros([0, 2])
@@ -2327,7 +2327,7 @@
       self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
 
   def testWhileWithRefsWithGradients_1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = variables.Variable(0.)._ref()  # pylint: disable=protected-access
       i = constant_op.constant(0)
       c = lambda i, x: math_ops.less(i, 10)
@@ -2355,7 +2355,7 @@
     self.assertEqual(73, value_x_grad)
 
   def testWhileGrad_IndexedSlices(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
       indices = constant_op.constant([0, 3], name="indices")
       shape = constant_op.constant([10], name="dense_shape")
@@ -2376,7 +2376,7 @@
       self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
 
   def testWhileGrad_SparseTensor(self):
-    with self.test_session():
+    with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
       indices = constant_op.constant(
           [[0], [3]], dtype=dtypes.int64, name="indices")
@@ -2398,7 +2398,7 @@
       self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
 
   def testCallGradInLoop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       i0 = constant_op.constant(0)
       params = constant_op.constant(5.0)
       params_1 = math_ops.square(params)
@@ -2417,7 +2417,7 @@
       self.assertAllClose(600.0, sess.run(output_grad)[1])
 
   def testWhileAndTensorArray(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       param = constant_op.constant(2.0)
       n0 = constant_op.constant(0)
       y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
@@ -2436,7 +2436,7 @@
       self.assertAllClose(107520.0, sess.run(r))
 
   def testWhileGrad_StopGrad(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(3.0, name="x")
       y = constant_op.constant(2.0, name="y")
 
@@ -2479,7 +2479,7 @@
       self.assertEqual(32.0, r.eval())
 
   def testWhileGrad_StopGradInside(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(3.0, name="x")
       y = constant_op.constant(2.0, name="y")
 
@@ -2498,7 +2498,7 @@
       self.assertAllClose(156.0, r.eval())
 
   def testWhileGrad_StopGradInsideNoShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       y = array_ops.placeholder(dtypes.float32)
 
@@ -2534,7 +2534,7 @@
     gradients_impl.gradients(grad_theta_stopped, theta)
 
   def testStopGradOnWhileGrad(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(2.0, name="x")
       y = constant_op.constant(2.0, name="y")
 
@@ -2562,7 +2562,7 @@
     _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
     dy_dq, = gradients_impl.gradients(y, q)
     self.assertIsNotNone(dy_dq)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(q.initializer)
       self.assertAllClose([0., 0.], sess.run(dy_dq))
 
@@ -2579,7 +2579,7 @@
     _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
     dy_dq, = gradients_impl.gradients(y, q)
     self.assertIsNotNone(dy_dq)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(q.initializer)
       self.assertAllClose([1., 1.], sess.run(dy_dq))
 
@@ -2607,7 +2607,7 @@
     self.assertIsNotNone(grad)
 
   def testStopGradMultiFlows(self):
-    with self.test_session():
+    with self.cached_session():
 
       def body(i, y, r):
         x = variable_scope.get_variable(
@@ -2633,10 +2633,10 @@
       self.assertEqual(5.0, result.eval())
 
   def testOneValueCond(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       c = array_ops.placeholder(dtypes.int32, shape=[])
       one = ops.convert_to_tensor(1, name="one")
       two = ops.convert_to_tensor(2, name="two")
@@ -2651,10 +2651,10 @@
       self.assertEqual([2], i.eval(feed_dict={c: 0}))
 
   def testExampleCond(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/111124878 (don't return tuple)")
 
-    with self.test_session():
+    with self.cached_session():
       x = ops.convert_to_tensor([-2.0, 2.0], name="x")
       d = array_ops.placeholder(dtypes.int32, shape=[])
 
@@ -2669,10 +2669,10 @@
       self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
 
   def testCase(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/112477618 (Operation returned from cond)")
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(1)
       y = constant_op.constant(2)
       z = constant_op.constant(3)
@@ -2724,10 +2724,10 @@
       self.assertAllEqual(r6.eval(), 0)
 
   def testCaseSideEffects(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/112477618 (Operation returned from cond)")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variables.Variable(-1)
       v1 = variables.Variable(-1)
       v2 = variables.Variable(-1)
@@ -2762,10 +2762,10 @@
       self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
 
   def testOneOpCond(self):
-    if control_flow_ops._ENABLE_COND_V2:
+    if control_flow_ops.ENABLE_COND_V2:
       return unittest.skip("b/113324949 (ref vars)")
 
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(0)
       c = ops.convert_to_tensor(0)
       one = ops.convert_to_tensor(1)
@@ -2793,7 +2793,7 @@
       self.assertEqual(2, v.eval())
 
   def testWithOpsDependencies(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = variables.Variable(0.0)
       c = constant_op.constant(10)
 
@@ -2816,7 +2816,7 @@
     self.assertAllClose(0.0, real_v_val)
 
   def testWithTensorDependencies(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(0.0)
       c1 = constant_op.constant(10)
       c2 = constant_op.constant(20)
@@ -2842,7 +2842,7 @@
       self.assertAllClose(0.0, v.eval())
 
   def testWithIndexedSlicesDependencies(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable(
           np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
       v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
@@ -2886,7 +2886,7 @@
         self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
 
   def testGroup(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v1 = variables.Variable([0.0])
       v2 = variables.Variable([1.0])
 
@@ -2997,7 +2997,7 @@
     self.assertEqual(None, s.get_shape())
 
   def testRunLoopTensor(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tensor_list = []
 
       def condition(t):
@@ -3021,7 +3021,7 @@
     def func(x):
       return np.square(x)
 
-    with self.test_session():
+    with self.cached_session():
       r = control_flow_ops.while_loop(
           lambda i, v: i < 4,
           lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
@@ -3035,7 +3035,7 @@
     def func(x):
       return math_ops.square(math_ops.square(x))
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(2.0, dtypes.float32)
       r = control_flow_ops.while_loop(
           lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
@@ -3174,7 +3174,7 @@
 
   def testTensors(self):
     for v1_first in [True, False]:
-      with self.test_session():
+      with self.cached_session():
         v1 = variables.Variable([1.0])
         add1 = math_ops.add(
             control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
@@ -3204,7 +3204,7 @@
 
   def testIndexedSlices(self):
     for v1_first in [True, False]:
-      with self.test_session():
+      with self.cached_session():
         v1 = variables.Variable(
             np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
                 np.float32))
@@ -3243,7 +3243,7 @@
                               v1.eval())
 
   def testAcceptTensorsAsControlInputs(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(0)
       assign = state_ops.assign(var, 1)
       t, = control_flow_ops.tuple(
@@ -3408,6 +3408,7 @@
         name="unroll_same_device", iters=iters, wall_time=duration)
 
 
+@test_util.with_cond_v2
 class EagerTest(test.TestCase):
 
   def testCond(self):
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py
index fcba456..2d6d8a8 100644
--- a/tensorflow/python/kernel_tests/conv1d_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_test.py
@@ -53,7 +53,7 @@
             self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
 
   def testConv1DTranspose(self):
-    with self.test_session():
+    with self.cached_session():
       stride = 2
 
       # Input, output: [batch, width, depth]
diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
index be299be..644a151 100644
--- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
@@ -32,7 +32,7 @@
 class Conv2DBackpropFilterGradTest(test.TestCase):
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       for padding in ["SAME", "VALID"]:
         for stride in [1, 2]:
           np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 27804be..cbdd2c5 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -37,7 +37,7 @@
 class Conv2DTransposeTest(test.TestCase):
 
   def testConv2DTransposeSingleStride(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 1, 1, 1]
 
       # Input, output: [batch, height, width, depth]
@@ -75,7 +75,7 @@
               self.assertAllClose(target, value[n, h, w, k])
 
   def testConv2DTransposeSame(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 2, 2, 1]
 
       # Input, output: [batch, height, width, depth]
@@ -108,7 +108,7 @@
               self.assertAllClose(target, value[n, h, w, k])
 
   def testConv2DTransposeValid(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 2, 2, 1]
 
       # Input, output: [batch, height, width, depth]
@@ -163,7 +163,7 @@
     np.random.seed(1)  # Make it reproducible.
     x_val = np.random.random_sample(x_shape).astype(np.float64)
     f_val = np.random.random_sample(f_shape).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
       f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
       output = nn_ops.conv2d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
index 85264ef..89b6406 100644
--- a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
@@ -32,7 +32,7 @@
 class Conv3DBackpropFilterV2GradTest(test.TestCase):
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       for padding in ["SAME", "VALID"]:
         for stride in [1, 2]:
           np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index 289ae29..2527b83 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -32,7 +32,7 @@
 class Conv3DTransposeTest(test.TestCase):
 
   def testConv3DTransposeSingleStride(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 1, 1, 1, 1]
 
       # Input, output: [batch, depth, height, width, channel]
@@ -82,7 +82,7 @@
                 self.assertAllClose(target, value[n, d, h, w, k])
 
   def testConv3DTransposeSame(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 2, 2, 2, 1]
 
       # Input, output: [batch, depth, height, width, depth]
@@ -134,7 +134,7 @@
   def testConv3DTransposeOutputShapeType(self):
     # Test case for GitHub issue 18887
     for dtype in [dtypes.int32, dtypes.int64]:
-      with self.test_session():
+      with self.cached_session():
         x_shape = [2, 5, 6, 4, 3]
         y_shape = [2, 5, 6, 4, 2]
         f_shape = [3, 3, 3, 2, 3]
@@ -149,7 +149,7 @@
         output.eval()
 
   def testConv3DTransposeValid(self):
-    with self.test_session():
+    with self.cached_session():
       strides = [1, 2, 2, 2, 1]
 
       # Input, output: [batch, depth, height, width, depth]
@@ -209,7 +209,7 @@
     np.random.seed(1)  # Make it reproducible.
     x_val = np.random.random_sample(x_shape).astype(np.float64)
     f_val = np.random.random_sample(f_shape).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
       f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
       output = nn_ops.conv3d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index 0b53112..6794464 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -108,7 +108,7 @@
             use_gpu=use_gpu)
         results.append(result)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         values = sess.run(results)
         for value in values:
           print("expected = ", expected)
@@ -183,7 +183,7 @@
         expected_results.append(expected)
         computed_results.append(computed)
         tolerance = 1e-2 if use_gpu else 1e-5
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           expected_values = sess.run(expected_results)
           computed_values = sess.run(computed_results)
           for e_value, c_value in zip(expected_values, computed_values):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 00de94f..ea61149 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1474,7 +1474,7 @@
           padding="SAME")
 
   def testOpEdgeCases(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Illegal strides.
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "strides in the batch and depth"):
@@ -1539,7 +1539,7 @@
     # numbers from 1.
     x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
     x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t1 = constant_op.constant(x1, shape=tensor_in_sizes)
       t1.set_shape(tensor_in_sizes)
       t2 = constant_op.constant(x2, shape=filter_in_sizes)
diff --git a/tensorflow/python/kernel_tests/cross_grad_test.py b/tensorflow/python/kernel_tests/cross_grad_test.py
index f040ac6..0bd4006 100644
--- a/tensorflow/python/kernel_tests/cross_grad_test.py
+++ b/tensorflow/python/kernel_tests/cross_grad_test.py
@@ -27,7 +27,7 @@
 class CrossOpTest(test.TestCase):
 
   def testGradientRandomValues(self):
-    with self.test_session():
+    with self.cached_session():
       us = [2, 3]
       u = array_ops.reshape(
           [0.854, -0.616, 0.767, 0.725, -0.927, 0.159], shape=us)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index b61232c..00d7f95 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -541,7 +541,7 @@
       return x
 
     for op, real_range in op_range:
-      with self.test_session():
+      with self.cached_session():
         for dtype, tol in dtype_tols:
           x = constant_op.constant(rand(dtype))
           y = constant_op.constant(rand(dtype))
@@ -604,7 +604,7 @@
                         numeric_gradient_type=None):
     z = np_func(x, y)
     zs = list(z.shape)
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       if x.dtype in (np.float32, np.float64):
@@ -634,7 +634,7 @@
                         numeric_gradient_type=None):
     z = np_func(x, y)
     zs = list(z.shape)
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       if x.dtype in (np.float32, np.float64):
@@ -720,7 +720,7 @@
   def testFloatDifferentShapes(self):
     x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
     y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       s = math_ops.reduce_sum(inx * iny)
@@ -736,7 +736,7 @@
     y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
     var_x = variables.Variable(x)
     var_y = variables.Variable(y)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run([var_x.initializer, var_y.initializer])
       left_result = (var_x * y).eval()
       right_result = (x * var_y).eval()
@@ -1168,7 +1168,7 @@
             ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
 
   def testZeroPowGrad(self):
-    with self.test_session():
+    with self.cached_session():
       for dtype in (np.float16, np.float32, np.float64, np.complex64,
                     np.complex128):
         x = constant_op.constant(0.0, dtype=dtype)
@@ -1178,7 +1178,7 @@
         self.assertEqual(error, 0)
 
   def testComplexPowGrad(self):
-    with self.test_session():
+    with self.cached_session():
       for dtype in np.complex64, np.complex128:
         for base in 2.0, -2.0:
           x = constant_op.constant(base, dtype=dtype)
@@ -1470,7 +1470,7 @@
     self.assertShapeEqual(np_ans, out)
 
   def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = array_ops.where(c, inx, iny)
@@ -1494,7 +1494,7 @@
       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
 
   def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = array_ops.where(c, inx, iny)
@@ -1582,7 +1582,7 @@
     x = np.random.rand(1, 3, 0) * 100
     y = np.random.rand(1, 3, 0) * 100
     z_expected = np.zeros((1, 3, 0), dtype=np.float32)
-    with self.test_session():
+    with self.cached_session():
       xt = x.astype(np.float32)
       yt = y.astype(np.float32)
       z = array_ops.where(c, xt, yt).eval()
@@ -1590,7 +1590,7 @@
 
   def testNan(self):
     """Verify that nans don't propagate where they shouldn't."""
-    with self.test_session():
+    with self.cached_session():
       for c in False, True:
         for a in 7.0, np.nan:
           for b in 5.0, np.nan:
@@ -1614,7 +1614,7 @@
     self.assertShapeEqual(np_ans, out)
 
   def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = array_ops.where(c, inx, iny)
@@ -1638,7 +1638,7 @@
       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
 
   def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = array_ops.where(c, inx, iny)
@@ -1745,7 +1745,7 @@
       self._compare(x.astype(t), t(y), use_gpu=True)
 
   def _compareGradientX(self, func, x, y):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = func(inx, iny)
@@ -1760,7 +1760,7 @@
       self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
 
   def _compareGradientY(self, func, x, y):
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       iny = ops.convert_to_tensor(y)
       out = func(inx, iny)
@@ -1932,7 +1932,7 @@
 
   def _compare_values(self, x, y=None):
     y = np.rint(x) if y is None else np.asarray(y)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tf_rint = math_ops.rint(x)
       np_rint = sess.run(tf_rint)
     self.assertAllEqual(y, np_rint)
@@ -1940,7 +1940,7 @@
 
   def _compare(self, x):
     np_floor, np_ceil = np.floor(x), np.ceil(x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inx = ops.convert_to_tensor(x)
       ofloor, oceil = math_ops.floor(inx), math_ops.ceil(inx)
       tf_floor, tf_ceil = sess.run([ofloor, oceil])
@@ -2099,7 +2099,7 @@
     # computes the squared sum. This is obviously the same as sum(real
     # * real) + sum(imag * imag). We just want to make sure the
     # gradient function is checked.
-    with self.test_session():
+    with self.cached_session():
       inx = ops.convert_to_tensor(x)
       real, imag = array_ops.split(value=inx, num_or_size_splits=2, axis=1)
       real, imag = array_ops.reshape(real, [-1]), array_ops.reshape(imag, [-1])
@@ -2116,7 +2116,7 @@
   def _compareBroadcastGradient(self, x):
     x_ = ops.convert_to_tensor(x)
     epsilon = 1e-3
-    with self.test_session():
+    with self.cached_session():
       for args in [(x_, 0.), (0., x_)]:
         z = math_ops.reduce_sum(math_ops.abs(math_ops.complex(*args)))
         jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -2136,7 +2136,7 @@
     # data is a float matrix of shape [n, 4].  data[:, 0], data[:, 1],
     # data[:, 2], data[:, 3] are real parts of x, imaginary parts of
     # x, real parts of y and imaginary parts of y.
-    with self.test_session():
+    with self.cached_session():
       inp = ops.convert_to_tensor(data)
       xr, xi, yr, yi = array_ops.split(value=inp, num_or_size_splits=4, axis=1)
 
@@ -2166,7 +2166,7 @@
 class AccumulateTest(test.TestCase):
 
   def testSimple(self):
-    with self.test_session():
+    with self.cached_session():
       random_arrays = [
           np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
       ]
@@ -2181,20 +2181,20 @@
       self.assertAllClose(np_val, tf_val.eval())
 
   def testZeroArgs(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         tf_val = math_ops.accumulate_n([])
         tf_val.eval()
 
   def testWrongShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         a = variables.Variable(0.2)
         b = variables.Variable(0.1)
         math_ops.accumulate_n([a, b], shape=[2, 2])  # Should be shape=[]
 
   def testWrongType(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         a = variables.Variable(0.2, dtype=np.float32)
         b = variables.Variable(0.1, dtype=np.float32)
@@ -2202,7 +2202,7 @@
 
   def testWrongTypeOneInput(self):
     # Scenario that used to trigger a bug, even when testWrongType() worked
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         a = variables.Variable(0.2, dtype=np.float32)
         math_ops.accumulate_n([a], tensor_dtype=np.int32)
@@ -2214,7 +2214,7 @@
     x = np.random.rand(2, 2).astype(dtype)
     coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)]
     np_val = np.polyval(coeffs, x)
-    with self.test_session():
+    with self.cached_session():
       tf_val = math_ops.polyval(coeffs, x)
       self.assertAllClose(np_val, tf_val.eval())
 
@@ -2237,7 +2237,7 @@
             for _ in range(degree + 1)
         ]
         np_val = np.polyval(coeffs, x)
-        with self.test_session():
+        with self.cached_session():
           tf_val = math_ops.polyval(coeffs, x)
           self.assertAllClose(np_val, tf_val.eval())
 
@@ -2245,7 +2245,7 @@
     x = np.random.rand(2, 2).astype(np.float32)
     coeffs = []
     np_val = np.polyval(coeffs, x)
-    with self.test_session():
+    with self.cached_session():
       tf_val = math_ops.polyval(coeffs, x)
       self.assertAllClose(np_val, tf_val.eval())
 
diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
index 35f8f76..eebaffb 100644
--- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
@@ -60,7 +60,7 @@
     img_in = constant_op.constant(byte_string, dtype=dtypes.string)
     decode = array_ops.squeeze(image_ops.decode_bmp(img_in))
 
-    with self.test_session():
+    with self.cached_session():
       decoded = decode.eval()
       self.assertAllEqual(decoded, img_bytes)
 
@@ -135,7 +135,7 @@
     img_in = constant_op.constant(byte_string, dtype=dtypes.string)
     decode = image_ops.decode_bmp(img_in)
 
-    with self.test_session():
+    with self.cached_session():
       decoded = decode.eval()
       self.assertAllEqual(decoded, img_bytes)
 
diff --git a/tensorflow/python/kernel_tests/decode_compressed_op_test.py b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
index c9bda58..1cc1c7d 100644
--- a/tensorflow/python/kernel_tests/decode_compressed_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
@@ -44,7 +44,7 @@
 
   def testDecompress(self):
     for compression_type in ["ZLIB", "GZIP", ""]:
-      with self.test_session():
+      with self.cached_session():
         in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
         decompressed = parsing_ops.decode_compressed(
             in_bytes, compression_type=compression_type)
@@ -57,7 +57,7 @@
 
   def testDecompressWithRaw(self):
     for compression_type in ["ZLIB", "GZIP", ""]:
-      with self.test_session():
+      with self.cached_session():
         in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
         decompressed = parsing_ops.decode_compressed(
             in_bytes, compression_type=compression_type)
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 4f49d72..e9307a6 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -20,28 +20,30 @@
 
 import numpy as np
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_in_graph_and_eager_modes
 class DecodeCSVOpTest(test.TestCase):
 
   def _test(self, args, expected_out=None, expected_err_re=None):
-    with self.test_session() as sess:
+    if expected_err_re is None:
       decode = parsing_ops.decode_csv(**args)
+      out = self.evaluate(decode)
 
-      if expected_err_re is None:
-        out = sess.run(decode)
-
-        for i, field in enumerate(out):
-          if field.dtype == np.float32 or field.dtype == np.float64:
-            self.assertAllClose(field, expected_out[i])
-          else:
-            self.assertAllEqual(field, expected_out[i])
-
-      else:
-        with self.assertRaisesOpError(expected_err_re):
-          sess.run(decode)
+      for i, field in enumerate(out):
+        if field.dtype == np.float32 or field.dtype == np.float64:
+          self.assertAllClose(field, expected_out[i])
+        else:
+          self.assertAllEqual(field, expected_out[i])
+    else:
+      with self.assertRaisesOpError(expected_err_re):
+        decode = parsing_ops.decode_csv(**args)
+        self.evaluate(decode)
 
   def testSimple(self):
     args = {
@@ -53,6 +55,31 @@
 
     self._test(args, expected_out)
 
+  def testSimpleWithScalarDefaults(self):
+    args = {
+        "records": ["1,4", "2,5", "3,6"],
+        "record_defaults": [1, 2],
+    }
+
+    expected_out = [[1, 2, 3], [4, 5, 6]]
+
+    self._test(args, expected_out)
+
+  def testSimpleWith2DDefaults(self):
+    args = {
+        "records": ["1", "2", "3"],
+        "record_defaults": [[[0]]],
+    }
+
+    if context.executing_eagerly():
+      err_spec = errors.InvalidArgumentError, (
+          "Each record default should be at "
+          "most rank 1.")
+    else:
+      err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
+    with self.assertRaisesWithPredicateMatch(*err_spec):
+      self._test(args)
+
   def testSimpleNoQuoteDelimiter(self):
     args = {
         "records": ["1", "2", '"3"'],
diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py
index 5828043..7f73fba 100644
--- a/tensorflow/python/kernel_tests/decode_image_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_image_op_test.py
@@ -111,7 +111,7 @@
   def testInvalidBytes(self):
     image_bytes = b"ThisIsNotAnImage!"
     decode = image_ops.decode_image(image_bytes)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         decode.eval()
 
diff --git a/tensorflow/python/kernel_tests/decode_png_op_test.py b/tensorflow/python/kernel_tests/decode_png_op_test.py
index d2e0393..8f36343 100644
--- a/tensorflow/python/kernel_tests/decode_png_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_png_op_test.py
@@ -46,7 +46,7 @@
         image_ops.decode_png(
             img_in, dtype=dtypes.uint16))
 
-    with self.test_session():
+    with self.cached_session():
       decoded = decode.eval()
       self.assertAllEqual(decoded, img_bytes)
 
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
index 122a9ed..dcc9848 100644
--- a/tensorflow/python/kernel_tests/decode_raw_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -29,7 +29,7 @@
 class DecodeRawOpTest(test.TestCase):
 
   def testToUint8(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
       decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint8)
       self.assertEqual([2, None], decode.get_shape().as_list())
@@ -47,7 +47,7 @@
         decode.eval(feed_dict={in_bytes: ["short", "longer"]})
 
   def testToInt16(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
       decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.int16)
       self.assertEqual([None, None], decode.get_shape().as_list())
@@ -62,7 +62,7 @@
         decode.eval(feed_dict={in_bytes: ["123", "456"]})
 
   def testEndianness(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
       decode_le = parsing_ops.decode_raw(
           in_bytes, out_type=dtypes.int32, little_endian=True)
@@ -74,18 +74,18 @@
       self.assertAllEqual([[0x01020304]], result)
 
   def testToFloat16(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
       decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
       self.assertEqual([None, None], decode.get_shape().as_list())
 
-      expected_result = np.matrix([[1, -2, -3, 4]], dtype=np.float16)
+      expected_result = np.matrix([[1, -2, -3, 4]], dtype="<f2")
       result = decode.eval(feed_dict={in_bytes: [expected_result.tostring()]})
 
       self.assertAllEqual(expected_result, result)
 
   def testEmptyStringInput(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
       decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
 
@@ -94,7 +94,7 @@
         self.assertEqual((num_inputs, 0), result.shape)
 
   def testToUInt16(self):
-    with self.test_session():
+    with self.cached_session():
       in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
       decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
       self.assertEqual([None, None], decode.get_shape().as_list())
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
index d33bf1b..affbaf1 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -33,7 +33,7 @@
   #   contain benign and deliberate data races when multiple threads update
   #   the same parameters without a lock.
   def testParallelUpdateWithoutLocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ones_t = array_ops.fill([1024, 1024], 1.0)
       p = variables.Variable(array_ops.zeros([1024, 1024]))
       adds = [
@@ -60,7 +60,7 @@
       self.assertTrue((vals <= ones * 20).all())
 
   def testParallelAssignWithoutLocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ones_t = array_ops.fill([1024, 1024], float(1))
       p = variables.Variable(array_ops.zeros([1024, 1024]))
       assigns = [
@@ -92,7 +92,7 @@
   # returning the output tensors. This issue will be resolved with the new
   # resource variables.
   def testParallelUpdateWithLocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       zeros_t = array_ops.fill([1024, 1024], 0.0)
       ones_t = array_ops.fill([1024, 1024], 1.0)
       p = variables.Variable(zeros_t)
@@ -119,7 +119,7 @@
       self.assertAllEqual(vals, ones * 20)
 
   def testParallelAssignWithLocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       zeros_t = array_ops.fill([1024, 1024], 0.0)
       ones_t = array_ops.fill([1024, 1024], 1.0)
       p = variables.Variable(zeros_t)
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 4dda9f0..06c3271 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -85,7 +85,7 @@
     self._testTypes(np.arange(0, 20).reshape([4, 5]))
 
   def testAssignNonStrictShapeChecking(self):
-    with self.test_session():
+    with self.cached_session():
       data = array_ops.fill([1024, 1024], 0)
       p = variables.Variable([1])
       a = state_ops.assign(p, data, validate_shape=False)
@@ -99,14 +99,14 @@
       self.assertAllEqual(p.eval(), data2.eval())
 
   def testInitRequiredAssignAdd(self):
-    with self.test_session():
+    with self.cached_session():
       p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
       a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
       with self.assertRaisesOpError("use uninitialized"):
         a.op.run()
 
   def testInitRequiredAssignSub(self):
-    with self.test_session():
+    with self.cached_session():
       p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
       a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
       with self.assertRaisesOpError("use uninitialized"):
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 5884555..5741f2e 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -205,6 +205,19 @@
             use_gpu=True,
             grouped_conv=True)
 
+  def testDepthwiseConv2DWithUnknownShape(self):
+    # GitHub issue 22110.
+    if not test.is_gpu_available():
+      return
+    with self.test_session(use_gpu=True):
+      x = array_ops.placeholder(dtypes.float32)
+      f = np.ones([1, 1, 1, 1], np.float32)
+      v = nn_impl.depthwise_conv2d(
+          x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW")
+      self.assertAllEqual(
+          np.ones([1, 1, 1, 1], np.float32),
+          v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)}))
+
   def testDepthwiseConv2DFormat(self):
     if not test.is_gpu_available():
       return
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 9ad77a5..26d013b 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -62,59 +62,50 @@
   def testP(self):
     p = [0.2, 0.4]
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(p, self.evaluate(dist.probs))
+    self.assertAllClose(p, self.evaluate(dist.probs))
 
   @test_util.run_in_graph_and_eager_modes
   def testLogits(self):
     logits = [-42., 42.]
     dist = bernoulli.Bernoulli(logits=logits)
-    with self.test_session():
-      self.assertAllClose(logits, self.evaluate(dist.logits))
+    self.assertAllClose(logits, self.evaluate(dist.logits))
 
     if not special:
       return
 
-    with self.test_session():
-      self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
+    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
 
     p = [0.01, 0.99, 0.42]
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
+    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
 
   @test_util.run_in_graph_and_eager_modes
   def testInvalidP(self):
     invalid_ps = [1.01, 2.]
     for p in invalid_ps:
-      with self.test_session():
-        with self.assertRaisesOpError("probs has components greater than 1"):
-          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-          self.evaluate(dist.probs)
+      with self.assertRaisesOpError("probs has components greater than 1"):
+        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+        self.evaluate(dist.probs)
 
     invalid_ps = [-0.01, -3.]
     for p in invalid_ps:
-      with self.test_session():
-        with self.assertRaisesOpError("Condition x >= 0"):
-          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-          self.evaluate(dist.probs)
+      with self.assertRaisesOpError("Condition x >= 0"):
+        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+        self.evaluate(dist.probs)
 
     valid_ps = [0.0, 0.5, 1.0]
     for p in valid_ps:
-      with self.test_session():
-        dist = bernoulli.Bernoulli(probs=p)
-        self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail
 
   @test_util.run_in_graph_and_eager_modes
   def testShapes(self):
-    with self.test_session():
-      for batch_shape in ([], [1], [2, 3, 4]):
-        dist = make_bernoulli(batch_shape)
-        self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
-        self.assertAllEqual(batch_shape,
-                            self.evaluate(dist.batch_shape_tensor()))
-        self.assertAllEqual([], dist.event_shape.as_list())
-        self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    for batch_shape in ([], [1], [2, 3, 4]):
+      dist = make_bernoulli(batch_shape)
+      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
+      self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
+      self.assertAllEqual([], dist.event_shape.as_list())
+      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
 
   @test_util.run_in_graph_and_eager_modes
   def testDtype(self):
@@ -137,31 +128,29 @@
   @test_util.run_in_graph_and_eager_modes
   def _testPmf(self, **kwargs):
     dist = bernoulli.Bernoulli(**kwargs)
-    with self.test_session():
-      # pylint: disable=bad-continuation
-      xs = [
-          0,
-          [1],
-          [1, 0],
-          [[1, 0]],
-          [[1, 0], [1, 1]],
-      ]
-      expected_pmfs = [
-          [[0.8, 0.6], [0.7, 0.4]],
-          [[0.2, 0.4], [0.3, 0.6]],
-          [[0.2, 0.6], [0.3, 0.4]],
-          [[0.2, 0.6], [0.3, 0.4]],
-          [[0.2, 0.6], [0.3, 0.6]],
-      ]
-      # pylint: enable=bad-continuation
+    # pylint: disable=bad-continuation
+    xs = [
+        0,
+        [1],
+        [1, 0],
+        [[1, 0]],
+        [[1, 0], [1, 1]],
+    ]
+    expected_pmfs = [
+        [[0.8, 0.6], [0.7, 0.4]],
+        [[0.2, 0.4], [0.3, 0.6]],
+        [[0.2, 0.6], [0.3, 0.4]],
+        [[0.2, 0.6], [0.3, 0.4]],
+        [[0.2, 0.6], [0.3, 0.6]],
+    ]
+    # pylint: enable=bad-continuation
 
-      for x, expected_pmf in zip(xs, expected_pmfs):
-        self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
-        self.assertAllClose(
-            self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
+    for x, expected_pmf in zip(xs, expected_pmfs):
+      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
+      self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
 
   def testPmfCorrectBroadcastDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtype=dtypes.float32)
       dist = bernoulli.Bernoulli(probs=p)
       event1 = [1, 0, 1]
@@ -178,12 +167,11 @@
   @test_util.run_in_graph_and_eager_modes
   def testPmfInvalid(self):
     p = [0.1, 0.2, 0.7]
-    with self.test_session():
-      dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-      with self.assertRaisesOpError("must be non-negative."):
-        self.evaluate(dist.prob([1, 1, -1]))
-      with self.assertRaisesOpError("Elements cannot exceed 1."):
-        self.evaluate(dist.prob([2, 0, 1]))
+    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+    with self.assertRaisesOpError("must be non-negative."):
+      self.evaluate(dist.prob([1, 1, -1]))
+    with self.assertRaisesOpError("Elements cannot exceed 1."):
+      self.evaluate(dist.prob([2, 0, 1]))
 
   @test_util.run_in_graph_and_eager_modes
   def testPmfWithP(self):
@@ -194,7 +182,7 @@
     self._testPmf(logits=special.logit(p))
 
   def testBroadcasting(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes.float32)
       dist = bernoulli.Bernoulli(probs=p)
       self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
@@ -208,70 +196,63 @@
           }))
 
   def testPmfShapes(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
       dist = bernoulli.Bernoulli(probs=p)
       self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=0.5)
       self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=0.5)
       self.assertEqual((), dist.log_prob(1).get_shape())
       self.assertEqual((1), dist.log_prob([1]).get_shape())
       self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
       self.assertEqual((2, 1), dist.log_prob(1).get_shape())
 
   @test_util.run_in_graph_and_eager_modes
   def testBoundaryConditions(self):
-    with self.test_session():
-      dist = bernoulli.Bernoulli(probs=1.0)
-      self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
-      self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
+    dist = bernoulli.Bernoulli(probs=1.0)
+    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
+    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
 
   @test_util.run_in_graph_and_eager_modes
   def testEntropyNoBatch(self):
     p = 0.2
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
+    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
 
   @test_util.run_in_graph_and_eager_modes
   def testEntropyWithBatch(self):
     p = [[0.1, 0.7], [0.2, 0.6]]
     dist = bernoulli.Bernoulli(probs=p, validate_args=False)
-    with self.test_session():
-      self.assertAllClose(
-          self.evaluate(dist.entropy()),
-          [[entropy(0.1), entropy(0.7)], [entropy(0.2),
-                                          entropy(0.6)]])
+    self.assertAllClose(
+        self.evaluate(dist.entropy()),
+        [[entropy(0.1), entropy(0.7)], [entropy(0.2),
+                                        entropy(0.6)]])
 
   @test_util.run_in_graph_and_eager_modes
   def testSampleN(self):
-    with self.test_session():
-      p = [0.2, 0.6]
-      dist = bernoulli.Bernoulli(probs=p)
-      n = 100000
-      samples = dist.sample(n)
-      samples.set_shape([n, 2])
-      self.assertEqual(samples.dtype, dtypes.int32)
-      sample_values = self.evaluate(samples)
-      self.assertTrue(np.all(sample_values >= 0))
-      self.assertTrue(np.all(sample_values <= 1))
-      # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
-      # n). This means that the tolerance is very sensitive to the value of p
-      # as well as n.
-      self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
-      self.assertEqual(set([0, 1]), set(sample_values.flatten()))
-      # In this test we're just interested in verifying there isn't a crash
-      # owing to mismatched types. b/30940152
-      dist = bernoulli.Bernoulli(np.log([.2, .4]))
-      self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+    p = [0.2, 0.6]
+    dist = bernoulli.Bernoulli(probs=p)
+    n = 100000
+    samples = dist.sample(n)
+    samples.set_shape([n, 2])
+    self.assertEqual(samples.dtype, dtypes.int32)
+    sample_values = self.evaluate(samples)
+    self.assertTrue(np.all(sample_values >= 0))
+    self.assertTrue(np.all(sample_values <= 1))
+    # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
+    # n). This means that the tolerance is very sensitive to the value of p
+    # as well as n.
+    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
+    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+    # In this test we're just interested in verifying there isn't a crash
+    # owing to mismatched types. b/30940152
+    dist = bernoulli.Bernoulli(np.log([.2, .4]))
+    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
 
   @test_util.run_in_graph_and_eager_modes
   def testNotReparameterized(self):
@@ -284,7 +265,7 @@
     self.assertIsNone(grad_p)
 
   def testSampleActsLikeSampleN(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p = [0.2, 0.6]
       dist = bernoulli.Bernoulli(probs=p)
       n = 1000
@@ -299,27 +280,24 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testMean(self):
-    with self.test_session():
-      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
-      dist = bernoulli.Bernoulli(probs=p)
-      self.assertAllEqual(self.evaluate(dist.mean()), p)
+    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+    dist = bernoulli.Bernoulli(probs=p)
+    self.assertAllEqual(self.evaluate(dist.mean()), p)
 
   @test_util.run_in_graph_and_eager_modes
   def testVarianceAndStd(self):
     var = lambda p: p * (1. - p)
-    with self.test_session():
-      p = [[0.2, 0.7], [0.5, 0.4]]
-      dist = bernoulli.Bernoulli(probs=p)
-      self.assertAllClose(
-          self.evaluate(dist.variance()),
-          np.array(
-              [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
-      self.assertAllClose(
-          self.evaluate(dist.stddev()),
-          np.array(
-              [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
-               [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
-              dtype=np.float32))
+    p = [[0.2, 0.7], [0.5, 0.4]]
+    dist = bernoulli.Bernoulli(probs=p)
+    self.assertAllClose(
+        self.evaluate(dist.variance()),
+        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
+                 dtype=np.float32))
+    self.assertAllClose(
+        self.evaluate(dist.stddev()),
+        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+                 dtype=np.float32))
 
   @test_util.run_in_graph_and_eager_modes
   def testBernoulliBernoulliKL(self):
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index 36f3ffc..d580a41 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -20,7 +20,6 @@
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import random_seed
@@ -51,237 +50,215 @@
 class BetaTest(test.TestCase):
 
   def testSimpleShapes(self):
-    with self.test_session():
-      a = np.random.rand(3)
-      b = np.random.rand(3)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+    a = np.random.rand(3)
+    b = np.random.rand(3)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
 
   def testComplexShapes(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2)
-      b = np.random.rand(3, 2, 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(
-          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+    a = np.random.rand(3, 2, 2)
+    b = np.random.rand(3, 2, 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
 
   def testComplexShapesBroadcast(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2)
-      b = np.random.rand(2, 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(
-          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+    a = np.random.rand(3, 2, 2)
+    b = np.random.rand(2, 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
 
   def testAlphaProperty(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual([1, 3], dist.concentration1.get_shape())
-      self.assertAllClose(a, self.evaluate(dist.concentration1))
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual([1, 3], dist.concentration1.get_shape())
+    self.assertAllClose(a, self.evaluate(dist.concentration1))
 
   def testBetaProperty(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual([1, 3], dist.concentration0.get_shape())
-      self.assertAllClose(b, self.evaluate(dist.concentration0))
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual([1, 3], dist.concentration0.get_shape())
+    self.assertAllClose(b, self.evaluate(dist.concentration0))
 
   def testPdfXProper(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b, validate_args=True)
-      self.evaluate(dist.prob([.1, .3, .6]))
-      self.evaluate(dist.prob([.2, .3, .5]))
-      # Either condition can trigger.
-      with self.assertRaisesOpError("sample must be positive"):
-        self.evaluate(dist.prob([-1., 0.1, 0.5]))
-      with self.assertRaisesOpError("sample must be positive"):
-        self.evaluate(dist.prob([0., 0.1, 0.5]))
-      with self.assertRaisesOpError("sample must be less than `1`"):
-        self.evaluate(dist.prob([.1, .2, 1.2]))
-      with self.assertRaisesOpError("sample must be less than `1`"):
-        self.evaluate(dist.prob([.1, .2, 1.0]))
+    dist = beta_lib.Beta(a, b, validate_args=True)
+    self.evaluate(dist.prob([.1, .3, .6]))
+    self.evaluate(dist.prob([.2, .3, .5]))
+    # Either condition can trigger.
+    with self.assertRaisesOpError("sample must be positive"):
+      self.evaluate(dist.prob([-1., 0.1, 0.5]))
+    with self.assertRaisesOpError("sample must be positive"):
+      self.evaluate(dist.prob([0., 0.1, 0.5]))
+    with self.assertRaisesOpError("sample must be less than `1`"):
+      self.evaluate(dist.prob([.1, .2, 1.2]))
+    with self.assertRaisesOpError("sample must be less than `1`"):
+      self.evaluate(dist.prob([.1, .2, 1.0]))
 
   def testPdfTwoBatches(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [.5, .5]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2,), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [.5, .5]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2,), pdf.get_shape())
 
   def testPdfTwoBatchesNontrivialX(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [.3, .7]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
-      self.assertEqual((2,), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [.3, .7]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
+    self.assertEqual((2,), pdf.get_shape())
 
   def testPdfUniformZeroBatch(self):
-    with self.test_session():
-      # This is equivalent to a uniform distribution
-      a = 1.
-      b = 1.
-      x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1.] * 5, self.evaluate(pdf))
-      self.assertEqual((5,), pdf.get_shape())
+    # This is equivalent to a uniform distribution
+    a = 1.
+    b = 1.
+    x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1.] * 5, self.evaluate(pdf))
+    self.assertEqual((5,), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      a = [[1., 2]]
-      b = [[1., 2]]
-      x = [[.5, .5], [.3, .7]]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2]]
+    b = [[1., 2]]
+    x = [[.5, .5], [.3, .7]]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [[.5, .5], [.2, .8]]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [[.5, .5], [.2, .8]]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      a = [[1., 2], [2., 3]]
-      b = [[1., 2], [2., 3]]
-      x = [[.5, .5]]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2], [2., 3]]
+    b = [[1., 2], [2., 3]]
+    x = [[.5, .5]]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      a = [[1., 2], [2., 3]]
-      b = [[1., 2], [2., 3]]
-      x = [.5, .5]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2], [2., 3]]
+    b = [[1., 2], [2., 3]]
+    x = [.5, .5]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testBetaMean(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_mean = stats.beta.mean(a, b)
-      self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_mean = stats.beta.mean(a, b)
+    self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
 
   def testBetaVariance(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variance = stats.beta.var(a, b)
-      self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variance = stats.beta.var(a, b)
+    self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
 
   def testBetaMode(self):
-    with session.Session():
-      a = np.array([1.1, 2, 3])
-      b = np.array([2., 4, 1.2])
-      expected_mode = (a - 1) / (a + b - 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.mode().get_shape(), (3,))
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    a = np.array([1.1, 2, 3])
+    b = np.array([2., 4, 1.2])
+    expected_mode = (a - 1) / (a + b - 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.mode().get_shape(), (3,))
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
   def testBetaModeInvalid(self):
-    with session.Session():
-      a = np.array([1., 2, 3])
-      b = np.array([2., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dist.mode())
+    a = np.array([1., 2, 3])
+    b = np.array([2., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dist.mode())
 
-      a = np.array([2., 2, 3])
-      b = np.array([1., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dist.mode())
+    a = np.array([2., 2, 3])
+    b = np.array([1., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dist.mode())
 
   def testBetaModeEnableAllowNanStats(self):
-    with session.Session():
-      a = np.array([1., 2, 3])
-      b = np.array([2., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+    a = np.array([1., 2, 3])
+    b = np.array([2., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=True)
 
-      expected_mode = (a - 1) / (a + b - 2)
-      expected_mode[0] = np.nan
-      self.assertEqual((3,), dist.mode().get_shape())
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    expected_mode = (a - 1) / (a + b - 2)
+    expected_mode[0] = np.nan
+    self.assertEqual((3,), dist.mode().get_shape())
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
-      a = np.array([2., 2, 3])
-      b = np.array([1., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+    a = np.array([2., 2, 3])
+    b = np.array([1., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=True)
 
-      expected_mode = (a - 1) / (a + b - 2)
-      expected_mode[0] = np.nan
-      self.assertEqual((3,), dist.mode().get_shape())
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    expected_mode = (a - 1) / (a + b - 2)
+    expected_mode[0] = np.nan
+    self.assertEqual((3,), dist.mode().get_shape())
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
   def testBetaEntropy(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.beta.entropy(a, b)
-      self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.beta.entropy(a, b)
+    self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
 
   def testBetaSample(self):
-    with self.test_session():
-      a = 1.
-      b = 2.
-      beta = beta_lib.Beta(a, b)
-      n = constant_op.constant(100000)
-      samples = beta.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000,))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      self.assertLess(
-          stats.kstest(
-              # Beta is a univariate distribution.
-              sample_values,
-              stats.beta(a=1., b=2.).cdf)[0],
-          0.01)
-      # The standard error of the sample mean is 1 / (sqrt(18 * n))
-      self.assertAllClose(
-          sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
-      self.assertAllClose(
-          np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+    a = 1.
+    b = 2.
+    beta = beta_lib.Beta(a, b)
+    n = constant_op.constant(100000)
+    samples = beta.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000,))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    self.assertLess(
+        stats.kstest(
+            # Beta is a univariate distribution.
+            sample_values,
+            stats.beta(a=1., b=2.).cdf)[0],
+        0.01)
+    # The standard error of the sample mean is 1 / (sqrt(18 * n))
+    self.assertAllClose(
+        sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
+    self.assertAllClose(
+        np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
 
   def testBetaFullyReparameterized(self):
     a = constant_op.constant(1.0)
@@ -297,78 +274,71 @@
 
   # Test that sampling with the same seed twice gives the same results.
   def testBetaSampleMultipleTimes(self):
-    with self.test_session():
-      a_val = 1.
-      b_val = 2.
-      n_val = 100
+    a_val = 1.
+    b_val = 2.
+    n_val = 100
 
-      random_seed.set_random_seed(654321)
-      beta1 = beta_lib.Beta(concentration1=a_val,
-                            concentration0=b_val,
-                            name="beta1")
-      samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
+    random_seed.set_random_seed(654321)
+    beta1 = beta_lib.Beta(
+        concentration1=a_val, concentration0=b_val, name="beta1")
+    samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
 
-      random_seed.set_random_seed(654321)
-      beta2 = beta_lib.Beta(concentration1=a_val,
-                            concentration0=b_val,
-                            name="beta2")
-      samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
+    random_seed.set_random_seed(654321)
+    beta2 = beta_lib.Beta(
+        concentration1=a_val, concentration0=b_val, name="beta2")
+    samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
 
-      self.assertAllClose(samples1, samples2)
+    self.assertAllClose(samples1, samples2)
 
   def testBetaSampleMultidimensional(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2).astype(np.float32)
-      b = np.random.rand(3, 2, 2).astype(np.float32)
-      beta = beta_lib.Beta(a, b)
-      n = constant_op.constant(100000)
-      samples = beta.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values[:, 1, :].mean(axis=0),
-          stats.beta.mean(a, b)[1, :],
-          atol=1e-1)
+    a = np.random.rand(3, 2, 2).astype(np.float32)
+    b = np.random.rand(3, 2, 2).astype(np.float32)
+    beta = beta_lib.Beta(a, b)
+    n = constant_op.constant(100000)
+    samples = beta.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values[:, 1, :].mean(axis=0),
+        stats.beta.mean(a, b)[1, :],
+        atol=1e-1)
 
   def testBetaCdf(self):
-    with self.test_session():
-      shape = (30, 40, 50)
-      for dt in (np.float32, np.float64):
-        a = 10. * np.random.random(shape).astype(dt)
-        b = 10. * np.random.random(shape).astype(dt)
-        x = np.random.random(shape).astype(dt)
-        actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
-        if not stats:
-          return
-        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    shape = (30, 40, 50)
+    for dt in (np.float32, np.float64):
+      a = 10. * np.random.random(shape).astype(dt)
+      b = 10. * np.random.random(shape).astype(dt)
+      x = np.random.random(shape).astype(dt)
+      actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+      if not stats:
+        return
+      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
 
   def testBetaLogCdf(self):
-    with self.test_session():
-      shape = (30, 40, 50)
-      for dt in (np.float32, np.float64):
-        a = 10. * np.random.random(shape).astype(dt)
-        b = 10. * np.random.random(shape).astype(dt)
-        x = np.random.random(shape).astype(dt)
-        actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
-        if not stats:
-          return
-        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    shape = (30, 40, 50)
+    for dt in (np.float32, np.float64):
+      a = 10. * np.random.random(shape).astype(dt)
+      b = 10. * np.random.random(shape).astype(dt)
+      x = np.random.random(shape).astype(dt)
+      actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+      if not stats:
+        return
+      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
 
   def testBetaWithSoftplusConcentration(self):
-    with self.test_session():
-      a, b = -4.2, -9.1
-      dist = beta_lib.BetaWithSoftplusConcentration(a, b)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
+    a, b = -4.2, -9.1
+    dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
 
   def testBetaBetaKL(self):
     for shape in [(10,), (4, 5)]:
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index 8b11556..e20f59f 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -36,11 +36,10 @@
   """Tests properties of the Bijector base-class."""
 
   def testIsAbstract(self):
-    with self.test_session():
-      with self.assertRaisesRegexp(TypeError,
-                                   ("Can't instantiate abstract class Bijector "
-                                    "with abstract methods __init__")):
-        bijector.Bijector()  # pylint: disable=abstract-class-instantiated
+    with self.assertRaisesRegexp(TypeError,
+                                 ("Can't instantiate abstract class Bijector "
+                                  "with abstract methods __init__")):
+      bijector.Bijector()  # pylint: disable=abstract-class-instantiated
 
   def testDefaults(self):
     class _BareBonesBijector(bijector.Bijector):
@@ -136,7 +135,7 @@
   def testBijectorDynamicEventNdims(self):
     bij = BrokenBijector(validate_args=True)
     event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Expected scalar"):
         bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
             event_ndims: (1, 2)})
@@ -308,7 +307,7 @@
     event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
     bij = ExpOnlyJacobian(forward_min_event_ndims=1)
     bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
                       feed_dict={event_ndims: 1})
     self.assertAllClose(-np.log(x_), ildj)
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 67ed044..cace5b3 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -49,115 +49,102 @@
 class DirichletTest(test.TestCase):
 
   def testSimpleShapes(self):
-    with self.test_session():
-      alpha = np.random.rand(3)
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
+    alpha = np.random.rand(3)
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
 
   def testComplexShapes(self):
-    with self.test_session():
-      alpha = np.random.rand(3, 2, 2)
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
+    alpha = np.random.rand(3, 2, 2)
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
 
   def testConcentrationProperty(self):
     alpha = [[1., 2, 3]]
-    with self.test_session():
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual([1, 3], dist.concentration.get_shape())
-      self.assertAllClose(alpha, self.evaluate(dist.concentration))
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual([1, 3], dist.concentration.get_shape())
+    self.assertAllClose(alpha, self.evaluate(dist.concentration))
 
   def testPdfXProper(self):
     alpha = [[1., 2, 3]]
-    with self.test_session():
-      dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
-      self.evaluate(dist.prob([.1, .3, .6]))
-      self.evaluate(dist.prob([.2, .3, .5]))
-      # Either condition can trigger.
-      with self.assertRaisesOpError("samples must be positive"):
-        self.evaluate(dist.prob([-1., 1.5, 0.5]))
-      with self.assertRaisesOpError("samples must be positive"):
-        self.evaluate(dist.prob([0., .1, .9]))
-      with self.assertRaisesOpError(
-          "sample last-dimension must sum to `1`"):
-        self.evaluate(dist.prob([.1, .2, .8]))
+    dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
+    self.evaluate(dist.prob([.1, .3, .6]))
+    self.evaluate(dist.prob([.2, .3, .5]))
+    # Either condition can trigger.
+    with self.assertRaisesOpError("samples must be positive"):
+      self.evaluate(dist.prob([-1., 1.5, 0.5]))
+    with self.assertRaisesOpError("samples must be positive"):
+      self.evaluate(dist.prob([0., .1, .9]))
+    with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
+      self.evaluate(dist.prob([.1, .2, .8]))
 
   def testPdfZeroBatches(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [.5, .5]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose(1., self.evaluate(pdf))
-      self.assertEqual((), pdf.get_shape())
+    alpha = [1., 2]
+    x = [.5, .5]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose(1., self.evaluate(pdf))
+    self.assertEqual((), pdf.get_shape())
 
   def testPdfZeroBatchesNontrivialX(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [.3, .7]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose(7. / 5, self.evaluate(pdf))
-      self.assertEqual((), pdf.get_shape())
+    alpha = [1., 2]
+    x = [.3, .7]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose(7. / 5, self.evaluate(pdf))
+    self.assertEqual((), pdf.get_shape())
 
   def testPdfUniformZeroBatches(self):
-    with self.test_session():
-      # Corresponds to a uniform distribution
-      alpha = [1., 1, 1]
-      x = [[.2, .5, .3], [.3, .4, .3]]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose([2., 2.], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    # Corresponds to a uniform distribution
+    alpha = [1., 1, 1]
+    x = [[.2, .5, .3], [.3, .4, .3]]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose([2., 2.], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      alpha = [[1., 2]]
-      x = [[.5, .5], [.3, .7]]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2]]
+    x = [[.5, .5], [.3, .7]]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [[.5, .5], [.2, .8]]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [1., 2]
+    x = [[.5, .5], [.2, .8]]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      alpha = [[1., 2], [2., 3]]
-      x = [[.5, .5]]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2], [2., 3]]
+    x = [[.5, .5]]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      alpha = [[1., 2], [2., 3]]
-      x = [.5, .5]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2], [2., 3]]
+    x = [.5, .5]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testMean(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.mean().get_shape(), [3])
-      if not stats:
-        return
-      expected_mean = stats.dirichlet.mean(alpha)
-      self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
+    alpha = [1., 2, 3]
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.mean().get_shape(), [3])
+    if not stats:
+      return
+    expected_mean = stats.dirichlet.mean(alpha)
+    self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
 
   def testCovarianceFromSampling(self):
     alpha = np.array([[1., 2, 3],
@@ -197,73 +184,66 @@
     self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
   def testVariance(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
-      if not stats:
-        return
-      expected_covariance = np.diag(stats.dirichlet.var(alpha))
-      expected_covariance += [[0., -2, -3], [-2, 0, -6],
-                              [-3, -6, 0]] / denominator
-      self.assertAllClose(
-          self.evaluate(dirichlet.covariance()), expected_covariance)
+    alpha = [1., 2, 3]
+    denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
+    if not stats:
+      return
+    expected_covariance = np.diag(stats.dirichlet.var(alpha))
+    expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
+                           ] / denominator
+    self.assertAllClose(
+        self.evaluate(dirichlet.covariance()), expected_covariance)
 
   def testMode(self):
-    with self.test_session():
-      alpha = np.array([1.1, 2, 3])
-      expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.mode().get_shape(), [3])
-      self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+    alpha = np.array([1.1, 2, 3])
+    expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.mode().get_shape(), [3])
+    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
 
   def testModeInvalid(self):
-    with self.test_session():
-      alpha = np.array([1., 2, 3])
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
-                                          allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dirichlet.mode())
+    alpha = np.array([1., 2, 3])
+    dirichlet = dirichlet_lib.Dirichlet(
+        concentration=alpha, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dirichlet.mode())
 
   def testModeEnableAllowNanStats(self):
-    with self.test_session():
-      alpha = np.array([1., 2, 3])
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
-                                          allow_nan_stats=True)
-      expected_mode = np.zeros_like(alpha) + np.nan
+    alpha = np.array([1., 2, 3])
+    dirichlet = dirichlet_lib.Dirichlet(
+        concentration=alpha, allow_nan_stats=True)
+    expected_mode = np.zeros_like(alpha) + np.nan
 
-      self.assertEqual(dirichlet.mode().get_shape(), [3])
-      self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+    self.assertEqual(dirichlet.mode().get_shape(), [3])
+    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
 
   def testEntropy(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.entropy().get_shape(), ())
-      if not stats:
-        return
-      expected_entropy = stats.dirichlet.entropy(alpha)
-      self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
+    alpha = [1., 2, 3]
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.entropy().get_shape(), ())
+    if not stats:
+      return
+    expected_entropy = stats.dirichlet.entropy(alpha)
+    self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
 
   def testSample(self):
-    with self.test_session():
-      alpha = [1., 2]
-      dirichlet = dirichlet_lib.Dirichlet(alpha)
-      n = constant_op.constant(100000)
-      samples = dirichlet.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertTrue(np.all(sample_values > 0.0))
-      if not stats:
-        return
-      self.assertLess(
-          stats.kstest(
-              # Beta is a univariate distribution.
-              sample_values[:, 0],
-              stats.beta(
-                  a=1., b=2.).cdf)[0],
-          0.01)
+    alpha = [1., 2]
+    dirichlet = dirichlet_lib.Dirichlet(alpha)
+    n = constant_op.constant(100000)
+    samples = dirichlet.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertTrue(np.all(sample_values > 0.0))
+    if not stats:
+      return
+    self.assertLess(
+        stats.kstest(
+            # Beta is a univariate distribution.
+            sample_values[:, 0],
+            stats.beta(a=1., b=2.).cdf)[0],
+        0.01)
 
   def testDirichletFullyReparameterized(self):
     alpha = constant_op.constant([1.0, 2.0, 3.0])
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 850da3e..27d1291 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -22,7 +22,6 @@
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
@@ -48,121 +47,108 @@
 class ExponentialTest(test.TestCase):
 
   def testExponentialLogPDF(self):
-    with session.Session():
-      batch_size = 6
-      lam = constant_op.constant([2.0] * batch_size)
-      lam_v = 2.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      exponential = exponential_lib.Exponential(rate=lam)
+    batch_size = 6
+    lam = constant_op.constant([2.0] * batch_size)
+    lam_v = 2.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      log_pdf = exponential.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
+    log_pdf = exponential.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
 
-      pdf = exponential.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
+    pdf = exponential.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testExponentialCDF(self):
-    with session.Session():
-      batch_size = 6
-      lam = constant_op.constant([2.0] * batch_size)
-      lam_v = 2.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    lam = constant_op.constant([2.0] * batch_size)
+    lam_v = 2.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      exponential = exponential_lib.Exponential(rate=lam)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      cdf = exponential.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
+    cdf = exponential.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
 
-      if not stats:
-        return
-      expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    if not stats:
+      return
+    expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testExponentialMean(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_mean = stats.expon.mean(scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_mean = stats.expon.mean(scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
 
   def testExponentialVariance(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variance = stats.expon.var(scale=1 / lam_v)
-      self.assertAllClose(
-          self.evaluate(exponential.variance()), expected_variance)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variance = stats.expon.var(scale=1 / lam_v)
+    self.assertAllClose(
+        self.evaluate(exponential.variance()), expected_variance)
 
   def testExponentialEntropy(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.expon.entropy(scale=1 / lam_v)
-      self.assertAllClose(
-          self.evaluate(exponential.entropy()), expected_entropy)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.expon.entropy(scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
 
   def testExponentialSample(self):
-    with self.test_session():
-      lam = constant_op.constant([3.0, 4.0])
-      lam_v = [3.0, 4.0]
-      n = constant_op.constant(100000)
-      exponential = exponential_lib.Exponential(rate=lam)
+    lam = constant_op.constant([3.0, 4.0])
+    lam_v = [3.0, 4.0]
+    n = constant_op.constant(100000)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      samples = exponential.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      for i in range(2):
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
+    samples = exponential.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    for i in range(2):
+      self.assertLess(
+          stats.kstest(sample_values[:, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
 
   def testExponentialSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 2
-      lam_v = [3.0, 22.0]
-      lam = constant_op.constant([lam_v] * batch_size)
+    batch_size = 2
+    lam_v = [3.0, 22.0]
+    lam = constant_op.constant([lam_v] * batch_size)
 
-      exponential = exponential_lib.Exponential(rate=lam)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      n = 100000
-      samples = exponential.sample(n, seed=138)
-      self.assertEqual(samples.get_shape(), (n, batch_size, 2))
+    n = 100000
+    samples = exponential.sample(n, seed=138)
+    self.assertEqual(samples.get_shape(), (n, batch_size, 2))
 
-      sample_values = self.evaluate(samples)
+    sample_values = self.evaluate(samples)
 
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      for i in range(2):
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, 0, i],
-                stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, 1, i],
-                stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    for i in range(2):
+      self.assertLess(
+          stats.kstest(sample_values[:, 0, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
+      self.assertLess(
+          stats.kstest(sample_values[:, 1, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
 
   def testFullyReparameterized(self):
     lam = constant_op.constant([0.1, 1.0])
@@ -174,11 +160,10 @@
     self.assertIsNotNone(grad_lam)
 
   def testExponentialWithSoftplusRate(self):
-    with self.test_session():
-      lam = [-2.2, -3.4]
-      exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
+    lam = [-2.2, -3.4]
+    exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 297e202..4eff40b 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -50,221 +50,203 @@
 class GammaTest(test.TestCase):
 
   def testGammaShape(self):
-    with self.test_session():
-      alpha = constant_op.constant([3.0] * 5)
-      beta = constant_op.constant(11.0)
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    alpha = constant_op.constant([3.0] * 5)
+    beta = constant_op.constant(11.0)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
 
-      self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
-      self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
-      self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
+    self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
+    self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
 
   def testGammaLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([2.0] * batch_size)
-      beta = constant_op.constant([3.0] * batch_size)
-      alpha_v = 2.0
-      beta_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
-      pdf = gamma.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    batch_size = 6
+    alpha = constant_op.constant([2.0] * batch_size)
+    beta = constant_op.constant([3.0] * batch_size)
+    alpha_v = 2.0
+    beta_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
+    pdf = gamma.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
-      beta = constant_op.constant([[3.0, 4.0]] * batch_size)
-      alpha_v = np.array([2.0, 4.0])
-      beta_v = np.array([3.0, 4.0])
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = gamma.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    batch_size = 6
+    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+    beta = constant_op.constant([[3.0, 4.0]] * batch_size)
+    alpha_v = np.array([2.0, 4.0])
+    beta_v = np.array([3.0, 4.0])
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = gamma.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensionalBroadcasting(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
-      beta = constant_op.constant(3.0)
-      alpha_v = np.array([2.0, 4.0])
-      beta_v = 3.0
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = gamma.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
+    batch_size = 6
+    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+    beta = constant_op.constant(3.0)
+    alpha_v = np.array([2.0, 4.0])
+    beta_v = 3.0
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = gamma.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testGammaCDF(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([2.0] * batch_size)
-      beta = constant_op.constant([3.0] * batch_size)
-      alpha_v = 2.0
-      beta_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    alpha = constant_op.constant([2.0] * batch_size)
+    beta = constant_op.constant([3.0] * batch_size)
+    alpha_v = 2.0
+    beta_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      cdf = gamma.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    cdf = gamma.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testGammaMean(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
 
   def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
-    with self.test_session():
-      alpha_v = np.array([5.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      expected_modes = (alpha_v - 1) / beta_v
-      self.assertEqual(gamma.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+    alpha_v = np.array([5.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    expected_modes = (alpha_v - 1) / beta_v
+    self.assertEqual(gamma.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
 
   def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
-    with self.test_session():
-      # Mode will not be defined for the first entry.
-      alpha_v = np.array([0.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v,
-                              rate=beta_v,
-                              allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(gamma.mode())
+    # Mode will not be defined for the first entry.
+    alpha_v = np.array([0.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(
+        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(gamma.mode())
 
   def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
-    with self.test_session():
-      # Mode will not be defined for the first entry.
-      alpha_v = np.array([0.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v,
-                              rate=beta_v,
-                              allow_nan_stats=True)
-      expected_modes = (alpha_v - 1) / beta_v
-      expected_modes[0] = np.nan
-      self.assertEqual(gamma.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+    # Mode will not be defined for the first entry.
+    alpha_v = np.array([0.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(
+        concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
+    expected_modes = (alpha_v - 1) / beta_v
+    expected_modes[0] = np.nan
+    self.assertEqual(gamma.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
 
   def testGammaVariance(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
 
   def testGammaStd(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.stddev().get_shape(), (3,))
-      if not stats:
-        return
-      expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
-      self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.stddev().get_shape(), (3,))
+    if not stats:
+      return
+    expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
+    self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
 
   def testGammaEntropy(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
 
   def testGammaSampleSmallAlpha(self):
-    with self.test_session():
-      alpha_v = 0.05
-      beta_v = 1.0
-      alpha = constant_op.constant(alpha_v)
-      beta = constant_op.constant(beta_v)
-      n = 100000
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.gamma.mean(
-              alpha_v, scale=1 / beta_v),
-          atol=.01)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.gamma.var(alpha_v, scale=1 / beta_v),
-          atol=.15)
+    alpha_v = 0.05
+    beta_v = 1.0
+    alpha = constant_op.constant(alpha_v)
+    beta = constant_op.constant(beta_v)
+    n = 100000
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.gamma.mean(alpha_v, scale=1 / beta_v),
+        atol=.01)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.gamma.var(alpha_v, scale=1 / beta_v),
+        atol=.15)
 
   def testGammaSample(self):
-    with self.test_session():
-      alpha_v = 4.0
-      beta_v = 3.0
-      alpha = constant_op.constant(alpha_v)
-      beta = constant_op.constant(beta_v)
-      n = 100000
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.gamma.mean(
-              alpha_v, scale=1 / beta_v),
-          atol=.01)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.gamma.var(alpha_v, scale=1 / beta_v),
-          atol=.15)
+    alpha_v = 4.0
+    beta_v = 3.0
+    alpha = constant_op.constant(alpha_v)
+    beta = constant_op.constant(beta_v)
+    n = 100000
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.gamma.mean(alpha_v, scale=1 / beta_v),
+        atol=.01)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.gamma.var(alpha_v, scale=1 / beta_v),
+        atol=.15)
 
   def testGammaFullyReparameterized(self):
     alpha = constant_op.constant(4.0)
@@ -279,37 +261,37 @@
     self.assertIsNotNone(grad_beta)
 
   def testGammaSampleMultiDimensional(self):
-    with self.test_session():
-      alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
-      beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      n = 10000
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n, 10, 100))
-      self.assertEqual(sample_values.shape, (n, 10, 100))
-      zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
-      alpha_bc = alpha_v + zeros
-      beta_bc = beta_v + zeros
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(axis=0),
-          stats.gamma.mean(
-              alpha_bc, scale=1 / beta_bc),
-          atol=0., rtol=.05)
-      self.assertAllClose(
-          sample_values.var(axis=0),
-          stats.gamma.var(alpha_bc, scale=1 / beta_bc),
-          atol=10.0, rtol=0.)
-      fails = 0
-      trials = 0
-      for ai, a in enumerate(np.reshape(alpha_v, [-1])):
-        for bi, b in enumerate(np.reshape(beta_v, [-1])):
-          s = sample_values[:, bi, ai]
-          trials += 1
-          fails += 0 if self._kstest(a, b, s) else 1
-      self.assertLess(fails, trials * 0.03)
+    alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
+    beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    n = 10000
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n, 10, 100))
+    self.assertEqual(sample_values.shape, (n, 10, 100))
+    zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
+    alpha_bc = alpha_v + zeros
+    beta_bc = beta_v + zeros
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(axis=0),
+        stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
+        atol=0.,
+        rtol=.05)
+    self.assertAllClose(
+        sample_values.var(axis=0),
+        stats.gamma.var(alpha_bc, scale=1 / beta_bc),
+        atol=10.0,
+        rtol=0.)
+    fails = 0
+    trials = 0
+    for ai, a in enumerate(np.reshape(alpha_v, [-1])):
+      for bi, b in enumerate(np.reshape(beta_v, [-1])):
+        s = sample_values[:, bi, ai]
+        trials += 1
+        fails += 0 if self._kstest(a, b, s) else 1
+    self.assertLess(fails, trials * 0.03)
 
   def _kstest(self, alpha, beta, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -320,30 +302,29 @@
     return ks < 0.02
 
   def testGammaPdfOfSampleMultiDims(self):
-    with self.test_session():
-      gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
-      num = 50000
-      samples = gamma.sample(num, seed=137)
-      pdfs = gamma.prob(samples)
-      sample_vals, pdf_vals = self.evaluate([samples, pdfs])
-      self.assertEqual(samples.get_shape(), (num, 2, 2))
-      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
-      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
-      if not stats:
-        return
-      self.assertAllClose(
-          stats.gamma.mean(
-              [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])),
-          sample_vals.mean(axis=0),
-          atol=.1)
-      self.assertAllClose(
-          stats.gamma.var([[7., 11.], [7., 11.]],
-                          scale=1 / np.array([[5., 5.], [6., 6.]])),
-          sample_vals.var(axis=0),
-          atol=.1)
+    gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
+    num = 50000
+    samples = gamma.sample(num, seed=137)
+    pdfs = gamma.prob(samples)
+    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
+    self.assertEqual(samples.get_shape(), (num, 2, 2))
+    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+    if not stats:
+      return
+    self.assertAllClose(
+        stats.gamma.mean([[7., 11.], [7., 11.]],
+                         scale=1 / np.array([[5., 5.], [6., 6.]])),
+        sample_vals.mean(axis=0),
+        atol=.1)
+    self.assertAllClose(
+        stats.gamma.var([[7., 11.], [7., 11.]],
+                        scale=1 / np.array([[5., 5.], [6., 6.]])),
+        sample_vals.var(axis=0),
+        atol=.1)
 
   def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
     s_p = zip(sample_vals, pdf_vals)
@@ -356,32 +337,29 @@
     self.assertNear(1., total, err=err)
 
   def testGammaNonPositiveInitializationParamsRaises(self):
-    with self.test_session():
-      alpha_v = constant_op.constant(0.0, name="alpha")
-      beta_v = constant_op.constant(1.0, name="beta")
-      with self.assertRaisesOpError("x > 0"):
-        gamma = gamma_lib.Gamma(concentration=alpha_v,
-                                rate=beta_v,
-                                validate_args=True)
-        self.evaluate(gamma.mean())
-      alpha_v = constant_op.constant(1.0, name="alpha")
-      beta_v = constant_op.constant(0.0, name="beta")
-      with self.assertRaisesOpError("x > 0"):
-        gamma = gamma_lib.Gamma(concentration=alpha_v,
-                                rate=beta_v,
-                                validate_args=True)
-        self.evaluate(gamma.mean())
+    alpha_v = constant_op.constant(0.0, name="alpha")
+    beta_v = constant_op.constant(1.0, name="beta")
+    with self.assertRaisesOpError("x > 0"):
+      gamma = gamma_lib.Gamma(
+          concentration=alpha_v, rate=beta_v, validate_args=True)
+      self.evaluate(gamma.mean())
+    alpha_v = constant_op.constant(1.0, name="alpha")
+    beta_v = constant_op.constant(0.0, name="beta")
+    with self.assertRaisesOpError("x > 0"):
+      gamma = gamma_lib.Gamma(
+          concentration=alpha_v, rate=beta_v, validate_args=True)
+      self.evaluate(gamma.mean())
 
   def testGammaWithSoftplusConcentrationRate(self):
-    with self.test_session():
-      alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
-      beta_v = constant_op.constant([1.0, -3.6], name="beta")
-      gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
-          concentration=alpha_v, rate=beta_v)
-      self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
-                          self.evaluate(gamma.concentration))
-      self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
-                          self.evaluate(gamma.rate))
+    alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
+    beta_v = constant_op.constant([1.0, -3.6], name="beta")
+    gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
+        concentration=alpha_v, rate=beta_v)
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(alpha_v)),
+        self.evaluate(gamma.concentration))
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))
 
   def testGammaGammaKL(self):
     alpha0 = np.array([3.])
@@ -391,15 +369,14 @@
     beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
 
     # Build graph.
-    with self.test_session():
-      g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
-      g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
-      x = g0.sample(int(1e4), seed=0)
-      kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
-      kl_actual = kullback_leibler.kl_divergence(g0, g1)
+    g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
+    g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
+    x = g0.sample(int(1e4), seed=0)
+    kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
+    kl_actual = kullback_leibler.kl_divergence(g0, g1)
 
-      # Execute graph.
-      [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
+    # Execute graph.
+    [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
 
     self.assertEqual(beta0.shape, kl_actual.get_shape())
 
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 24b243f..630c2cb 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -21,7 +21,6 @@
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import tensor_shape
@@ -49,212 +48,198 @@
 class LaplaceTest(test.TestCase):
 
   def testLaplaceShape(self):
-    with self.test_session():
-      loc = constant_op.constant([3.0] * 5)
-      scale = constant_op.constant(11.0)
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    loc = constant_op.constant([3.0] * 5)
+    scale = constant_op.constant(11.0)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
-      self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
-      self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
+    self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
+    self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
 
   def testLaplaceLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
 
-      pdf = laplace.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testLaplaceLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([[2.0, 4.0]] * batch_size)
-      scale = constant_op.constant([[3.0, 4.0]] * batch_size)
-      loc_v = np.array([2.0, 4.0])
-      scale_v = np.array([3.0, 4.0])
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
+    batch_size = 6
+    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+    scale = constant_op.constant([[3.0, 4.0]] * batch_size)
+    loc_v = np.array([2.0, 4.0])
+    scale_v = np.array([3.0, 4.0])
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
 
-      pdf = laplace.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceLogPDFMultidimensionalBroadcasting(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([[2.0, 4.0]] * batch_size)
-      scale = constant_op.constant(3.0)
-      loc_v = np.array([2.0, 4.0])
-      scale_v = 3.0
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
+    batch_size = 6
+    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+    scale = constant_op.constant(3.0)
+    loc_v = np.array([2.0, 4.0])
+    scale_v = 3.0
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
 
-      pdf = laplace.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceCDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      cdf = laplace.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    cdf = laplace.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testLaplaceLogCDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      cdf = laplace.log_cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    cdf = laplace.log_cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testLaplaceLogSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      sf = laplace.log_survival_function(x)
-      self.assertEqual(sf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(sf), expected_sf)
+    sf = laplace.log_survival_function(x)
+    self.assertEqual(sf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(sf), expected_sf)
 
   def testLaplaceMean(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_means = stats.laplace.mean(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_means = stats.laplace.mean(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
 
   def testLaplaceMode(self):
-    with self.test_session():
-      loc_v = np.array([0.5, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
+    loc_v = np.array([0.5, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
 
   def testLaplaceVariance(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variances = stats.laplace.var(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variances = stats.laplace.var(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
 
   def testLaplaceStd(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.stddev().get_shape(), (3,))
-      if not stats:
-        return
-      expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.stddev().get_shape(), (3,))
+    if not stats:
+      return
+    expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
 
   def testLaplaceEntropy(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
 
   def testLaplaceSample(self):
-    with session.Session():
-      loc_v = 4.0
-      scale_v = 3.0
-      loc = constant_op.constant(loc_v)
-      scale = constant_op.constant(scale_v)
-      n = 100000
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      samples = laplace.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.laplace.mean(
-              loc_v, scale=scale_v),
-          rtol=0.05,
-          atol=0.)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.laplace.var(loc_v, scale=scale_v),
-          rtol=0.05,
-          atol=0.)
-      self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+    loc_v = 4.0
+    scale_v = 3.0
+    loc = constant_op.constant(loc_v)
+    scale = constant_op.constant(scale_v)
+    n = 100000
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    samples = laplace.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.laplace.mean(loc_v, scale=scale_v),
+        rtol=0.05,
+        atol=0.)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.laplace.var(loc_v, scale=scale_v),
+        rtol=0.05,
+        atol=0.)
+    self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
 
   def testLaplaceFullyReparameterized(self):
     loc = constant_op.constant(4.0)
@@ -269,39 +254,37 @@
     self.assertIsNotNone(grad_scale)
 
   def testLaplaceSampleMultiDimensional(self):
-    with session.Session():
-      loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
-      scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      n = 10000
-      samples = laplace.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n, 10, 100))
-      self.assertEqual(sample_values.shape, (n, 10, 100))
-      zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
-      loc_bc = loc_v + zeros
-      scale_bc = scale_v + zeros
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(axis=0),
-          stats.laplace.mean(
-              loc_bc, scale=scale_bc),
-          rtol=0.35,
-          atol=0.)
-      self.assertAllClose(
-          sample_values.var(axis=0),
-          stats.laplace.var(loc_bc, scale=scale_bc),
-          rtol=0.10,
-          atol=0.)
-      fails = 0
-      trials = 0
-      for ai, a in enumerate(np.reshape(loc_v, [-1])):
-        for bi, b in enumerate(np.reshape(scale_v, [-1])):
-          s = sample_values[:, bi, ai]
-          trials += 1
-          fails += 0 if self._kstest(a, b, s) else 1
-      self.assertLess(fails, trials * 0.03)
+    loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
+    scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    n = 10000
+    samples = laplace.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n, 10, 100))
+    self.assertEqual(sample_values.shape, (n, 10, 100))
+    zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
+    loc_bc = loc_v + zeros
+    scale_bc = scale_v + zeros
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(axis=0),
+        stats.laplace.mean(loc_bc, scale=scale_bc),
+        rtol=0.35,
+        atol=0.)
+    self.assertAllClose(
+        sample_values.var(axis=0),
+        stats.laplace.var(loc_bc, scale=scale_bc),
+        rtol=0.10,
+        atol=0.)
+    fails = 0
+    trials = 0
+    for ai, a in enumerate(np.reshape(loc_v, [-1])):
+      for bi, b in enumerate(np.reshape(scale_v, [-1])):
+        s = sample_values[:, bi, ai]
+        trials += 1
+        fails += 0 if self._kstest(a, b, s) else 1
+    self.assertLess(fails, trials * 0.03)
 
   def _kstest(self, loc, scale, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -349,30 +332,26 @@
     self.assertNear(1., total, err=err)
 
   def testLaplaceNonPositiveInitializationParamsRaises(self):
-    with self.test_session():
-      loc_v = constant_op.constant(0.0, name="loc")
-      scale_v = constant_op.constant(-1.0, name="scale")
-      with self.assertRaisesOpError(
-          "Condition x > 0 did not hold element-wise"):
-        laplace = laplace_lib.Laplace(
-            loc=loc_v, scale=scale_v, validate_args=True)
-        self.evaluate(laplace.mean())
-      loc_v = constant_op.constant(1.0, name="loc")
-      scale_v = constant_op.constant(0.0, name="scale")
-      with self.assertRaisesOpError(
-          "Condition x > 0 did not hold element-wise"):
-        laplace = laplace_lib.Laplace(
-            loc=loc_v, scale=scale_v, validate_args=True)
-        self.evaluate(laplace.mean())
+    loc_v = constant_op.constant(0.0, name="loc")
+    scale_v = constant_op.constant(-1.0, name="scale")
+    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+      laplace = laplace_lib.Laplace(
+          loc=loc_v, scale=scale_v, validate_args=True)
+      self.evaluate(laplace.mean())
+    loc_v = constant_op.constant(1.0, name="loc")
+    scale_v = constant_op.constant(0.0, name="scale")
+    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+      laplace = laplace_lib.Laplace(
+          loc=loc_v, scale=scale_v, validate_args=True)
+      self.evaluate(laplace.mean())
 
   def testLaplaceWithSoftplusScale(self):
-    with self.test_session():
-      loc_v = constant_op.constant([0.0, 1.0], name="loc")
-      scale_v = constant_op.constant([-1.0, 2.0], name="scale")
-      laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
-      self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
+    loc_v = constant_op.constant([0.0, 1.0], name="loc")
+    scale_v = constant_op.constant([-1.0, 2.0], name="scale")
+    laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
+    self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index 7ff48c0..de73a40 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -61,16 +61,15 @@
     self.assertAllEqual(all_true, is_finite)
 
   def _testParamShapes(self, sample_shape, expected):
-    with self.test_session():
-      param_shapes = normal_lib.Normal.param_shapes(sample_shape)
-      mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
-      self.assertAllEqual(expected, self.evaluate(mu_shape))
-      self.assertAllEqual(expected, self.evaluate(sigma_shape))
-      mu = array_ops.zeros(mu_shape)
-      sigma = array_ops.ones(sigma_shape)
-      self.assertAllEqual(
-          expected,
-          self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
+    param_shapes = normal_lib.Normal.param_shapes(sample_shape)
+    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
+    self.assertAllEqual(expected, self.evaluate(mu_shape))
+    self.assertAllEqual(expected, self.evaluate(sigma_shape))
+    mu = array_ops.zeros(mu_shape)
+    sigma = array_ops.ones(sigma_shape)
+    self.assertAllEqual(
+        expected,
+        self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
 
   def _testParamStaticShapes(self, sample_shape, expected):
     param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
@@ -91,156 +90,150 @@
     self._testParamStaticShapes(
         tensor_shape.TensorShape(sample_shape), sample_shape)
 
-  @test_util.run_in_graph_and_eager_modes
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNormalWithSoftplusScale(self):
-    with self.test_session():
-      mu = array_ops.zeros((10, 3))
-      rho = array_ops.ones((10, 3)) * -2.
-      normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
-      self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
-      self.assertAllEqual(
-          self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
+    mu = array_ops.zeros((10, 3))
+    rho = array_ops.ones((10, 3)) * -2.
+    normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
+    self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      mu = constant_op.constant([3.0] * batch_size)
-      sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
-      x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    batch_size = 6
+    mu = constant_op.constant([3.0] * batch_size)
+    sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
+    x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      log_pdf = normal.log_prob(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(log_pdf).shape)
-      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+    log_pdf = normal.log_prob(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(log_pdf).shape)
+    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
 
-      pdf = normal.prob(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(pdf).shape)
-      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
+    pdf = normal.prob(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(pdf).shape)
+    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
 
-      if not stats:
-        return
-      expected_log_pdf = stats.norm(self.evaluate(mu),
-                                    self.evaluate(sigma)).logpdf(x)
-      self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
-      self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.norm(self.evaluate(mu),
+                                  self.evaluate(sigma)).logpdf(x)
+    self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
+    self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      mu = constant_op.constant([[3.0, -3.0]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] *
-                                   batch_size)
-      x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    batch_size = 6
+    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+    x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      log_pdf = normal.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(log_pdf).shape)
-      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+    log_pdf = normal.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(log_pdf).shape)
+    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
 
-      pdf = normal.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
-      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+    pdf = normal.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
+    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, pdf_values.shape)
 
-      if not stats:
-        return
-      expected_log_pdf = stats.norm(self.evaluate(mu),
-                                    self.evaluate(sigma)).logpdf(x)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    if not stats:
+      return
+    expected_log_pdf = stats.norm(self.evaluate(mu),
+                                  self.evaluate(sigma)).logpdf(x)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalCDF(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      cdf = normal.cdf(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(cdf).shape)
-      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
-      if not stats:
-        return
-      expected_cdf = stats.norm(mu, sigma).cdf(x)
-      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    cdf = normal.cdf(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(cdf).shape)
+    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+    if not stats:
+      return
+    expected_cdf = stats.norm(mu, sigma).cdf(x)
+    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      sf = normal.survival_function(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(sf).shape)
-      self.assertAllEqual(normal.batch_shape, sf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
-      if not stats:
-        return
-      expected_sf = stats.norm(mu, sigma).sf(x)
-      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
+    sf = normal.survival_function(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(sf).shape)
+    self.assertAllEqual(normal.batch_shape, sf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+    if not stats:
+      return
+    expected_sf = stats.norm(mu, sigma).sf(x)
+    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogCDF(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      cdf = normal.log_cdf(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(cdf).shape)
-      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+    cdf = normal.log_cdf(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(cdf).shape)
+    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
 
-      if not stats:
-        return
-      expected_cdf = stats.norm(mu, sigma).logcdf(x)
-      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
+    if not stats:
+      return
+    expected_cdf = stats.norm(mu, sigma).logcdf(x)
+    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
 
   def testFiniteGradientAtDifficultPoints(self):
     for dtype in [np.float32, np.float64]:
@@ -256,7 +249,7 @@
         ]:
           value = func(x)
           grads = gradients_impl.gradients(value, [mu, sigma])
-          with self.test_session(graph=g):
+          with self.session(graph=g):
             variables.global_variables_initializer().run()
             self.assertAllFinite(value)
             self.assertAllFinite(grads[0])
@@ -264,112 +257,106 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      sf = normal.log_survival_function(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(sf).shape)
-      self.assertAllEqual(normal.batch_shape, sf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+    sf = normal.log_survival_function(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(sf).shape)
+    self.assertAllEqual(normal.batch_shape, sf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
 
-      if not stats:
-        return
-      expected_sf = stats.norm(mu, sigma).logsf(x)
-      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
+    if not stats:
+      return
+    expected_sf = stats.norm(mu, sigma).logsf(x)
+    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalEntropyWithScalarInputs(self):
     # Scipy.stats.norm cannot deal with the shapes in the other test.
-    with self.test_session():
-      mu_v = 2.34
-      sigma_v = 4.56
-      normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+    mu_v = 2.34
+    sigma_v = 4.56
+    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
 
-      entropy = normal.entropy()
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(entropy).shape)
-      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
-      # scipy.stats.norm cannot deal with these shapes.
-      if not stats:
-        return
-      expected_entropy = stats.norm(mu_v, sigma_v).entropy()
-      self.assertAllClose(expected_entropy, self.evaluate(entropy))
+    entropy = normal.entropy()
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(entropy).shape)
+    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+    # scipy.stats.norm cannot deal with these shapes.
+    if not stats:
+      return
+    expected_entropy = stats.norm(mu_v, sigma_v).entropy()
+    self.assertAllClose(expected_entropy, self.evaluate(entropy))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalEntropy(self):
-    with self.test_session():
-      mu_v = np.array([1.0, 1.0, 1.0])
-      sigma_v = np.array([[1.0, 2.0, 3.0]]).T
-      normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+    mu_v = np.array([1.0, 1.0, 1.0])
+    sigma_v = np.array([[1.0, 2.0, 3.0]]).T
+    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
 
-      # scipy.stats.norm cannot deal with these shapes.
-      sigma_broadcast = mu_v * sigma_v
-      expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**
-                                      2)
-      entropy = normal.entropy()
-      np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(entropy).shape)
-      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+    # scipy.stats.norm cannot deal with these shapes.
+    sigma_broadcast = mu_v * sigma_v
+    expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
+    entropy = normal.entropy()
+    np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(entropy).shape)
+    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
 
-  @test_util.run_in_graph_and_eager_modes
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNormalMeanAndMode(self):
-    with self.test_session():
-      # Mu will be broadcast to [7, 7, 7].
-      mu = [7.]
-      sigma = [11., 12., 13.]
+    # Mu will be broadcast to [7, 7, 7].
+    mu = [7.]
+    sigma = [11., 12., 13.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.mean().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
+    self.assertAllEqual((3,), normal.mean().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
 
-      self.assertAllEqual((3,), normal.mode().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
+    self.assertAllEqual((3,), normal.mode().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalQuantile(self):
-    with self.test_session():
-      batch_size = 52
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
-      # Quantile performs piecewise rational approximation so adding some
-      # special input values to make sure we hit all the pieces.
-      p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
+    batch_size = 52
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
+    # Quantile performs piecewise rational approximation so adding some
+    # special input values to make sure we hit all the pieces.
+    p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      x = normal.quantile(p)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    x = normal.quantile(p)
 
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), x.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(x).shape)
-      self.assertAllEqual(normal.batch_shape, x.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), x.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(x).shape)
+    self.assertAllEqual(normal.batch_shape, x.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
 
-      if not stats:
-        return
-      expected_x = stats.norm(mu, sigma).ppf(p)
-      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+    if not stats:
+      return
+    expected_x = stats.norm(mu, sigma).ppf(p)
+    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
 
   def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
     g = ops.Graph()
@@ -385,7 +372,7 @@
 
       value = dist.quantile(p)
       grads = gradients_impl.gradients(value, [mu, p])
-      with self.test_session(graph=g):
+      with self.cached_session(graph=g):
         variables.global_variables_initializer().run()
         self.assertAllFinite(grads[0])
         self.assertAllFinite(grads[1])
@@ -398,61 +385,58 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalVariance(self):
-    with self.test_session():
-      # sigma will be broadcast to [7, 7, 7]
-      mu = [1., 2., 3.]
-      sigma = [7.]
+    # sigma will be broadcast to [7, 7, 7]
+    mu = [1., 2., 3.]
+    sigma = [7.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.variance().get_shape())
-      self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
+    self.assertAllEqual((3,), normal.variance().get_shape())
+    self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalStandardDeviation(self):
-    with self.test_session():
-      # sigma will be broadcast to [7, 7, 7]
-      mu = [1., 2., 3.]
-      sigma = [7.]
+    # sigma will be broadcast to [7, 7, 7]
+    mu = [1., 2., 3.]
+    sigma = [7.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.stddev().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
+    self.assertAllEqual((3,), normal.stddev().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSample(self):
-    with self.test_session():
-      mu = constant_op.constant(3.0)
-      sigma = constant_op.constant(math.sqrt(3.0))
-      mu_v = 3.0
-      sigma_v = np.sqrt(3.0)
-      n = constant_op.constant(100000)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      samples = normal.sample(n)
-      sample_values = self.evaluate(samples)
-      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
-      # The sample variance similarly is dependent on sigma and n.
-      # Thus, the tolerances below are very sensitive to number of samples
-      # as well as the variances chosen.
-      self.assertEqual(sample_values.shape, (100000,))
-      self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
-      self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
+    mu = constant_op.constant(3.0)
+    sigma = constant_op.constant(math.sqrt(3.0))
+    mu_v = 3.0
+    sigma_v = np.sqrt(3.0)
+    n = constant_op.constant(100000)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    samples = normal.sample(n)
+    sample_values = self.evaluate(samples)
+    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+    # The sample variance similarly is dependent on sigma and n.
+    # Thus, the tolerances below are very sensitive to number of samples
+    # as well as the variances chosen.
+    self.assertEqual(sample_values.shape, (100000,))
+    self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
+    self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
 
-      expected_samples_shape = tensor_shape.TensorShape(
-          [self.evaluate(n)]).concatenate(
-              tensor_shape.TensorShape(
-                  self.evaluate(normal.batch_shape_tensor())))
+    expected_samples_shape = tensor_shape.TensorShape(
+        [self.evaluate(n)]).concatenate(
+            tensor_shape.TensorShape(
+                self.evaluate(normal.batch_shape_tensor())))
 
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
-      expected_samples_shape = (
-          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
-              normal.batch_shape))
+    expected_samples_shape = (
+        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+            normal.batch_shape))
 
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
   def testNormalFullyReparameterized(self):
     mu = constant_op.constant(4.0)
@@ -468,66 +452,63 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 2
-      mu = constant_op.constant([[3.0, -3.0]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] *
-                                   batch_size)
-      mu_v = [3.0, -3.0]
-      sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
-      n = constant_op.constant(100000)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      samples = normal.sample(n)
-      sample_values = self.evaluate(samples)
-      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
-      # The sample variance similarly is dependent on sigma and n.
-      # Thus, the tolerances below are very sensitive to number of samples
-      # as well as the variances chosen.
-      self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
-      self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
+    batch_size = 2
+    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
+    mu_v = [3.0, -3.0]
+    sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
+    n = constant_op.constant(100000)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    samples = normal.sample(n)
+    sample_values = self.evaluate(samples)
+    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+    # The sample variance similarly is dependent on sigma and n.
+    # Thus, the tolerances below are very sensitive to number of samples
+    # as well as the variances chosen.
+    self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
+    self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
 
-      expected_samples_shape = tensor_shape.TensorShape(
-          [self.evaluate(n)]).concatenate(
-              tensor_shape.TensorShape(
-                  self.evaluate(normal.batch_shape_tensor())))
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    expected_samples_shape = tensor_shape.TensorShape(
+        [self.evaluate(n)]).concatenate(
+            tensor_shape.TensorShape(
+                self.evaluate(normal.batch_shape_tensor())))
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
-      expected_samples_shape = (
-          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
-              normal.batch_shape))
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    expected_samples_shape = (
+        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+            normal.batch_shape))
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
   @test_util.run_in_graph_and_eager_modes
   def testNegativeSigmaFails(self):
-    with self.test_session():
-      with self.assertRaisesOpError("Condition x > 0 did not hold"):
-        normal = normal_lib.Normal(
-            loc=[1.], scale=[-5.], validate_args=True, name="G")
-        self.evaluate(normal.mean())
+    with self.assertRaisesOpError("Condition x > 0 did not hold"):
+      normal = normal_lib.Normal(
+          loc=[1.], scale=[-5.], validate_args=True, name="G")
+      self.evaluate(normal.mean())
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalShape(self):
-    with self.test_session():
-      mu = constant_op.constant([-3.0] * 5)
-      sigma = constant_op.constant(11.0)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    mu = constant_op.constant([-3.0] * 5)
+    sigma = constant_op.constant(11.0)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
-      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
-      self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
+    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
+    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
 
   def testNormalShapeWithPlaceholders(self):
     mu = array_ops.placeholder(dtype=dtypes.float32)
     sigma = array_ops.placeholder(dtype=dtypes.float32)
     normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # get_batch_shape should return an "<unknown>" tensor.
       self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
       self.assertEqual(normal.event_shape, ())
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index a634194..cc43e12 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -92,22 +92,21 @@
   @test_util.run_in_graph_and_eager_modes
   def testNdtri(self):
     """Verifies that ndtri computation is correct."""
-    with self.test_session():
-      if not special:
-        return
+    if not special:
+      return
 
-      p = np.linspace(0., 1.0, 50).astype(np.float64)
-      # Quantile performs piecewise rational approximation so adding some
-      # special input values to make sure we hit all the pieces.
-      p = np.hstack((p, np.exp(-32), 1. - np.exp(-32),
-                     np.exp(-2), 1. - np.exp(-2)))
-      expected_x = special.ndtri(p)
-      x = special_math.ndtri(p)
-      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+    p = np.linspace(0., 1.0, 50).astype(np.float64)
+    # Quantile performs piecewise rational approximation so adding some
+    # special input values to make sure we hit all the pieces.
+    p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
+                   1. - np.exp(-2)))
+    expected_x = special.ndtri(p)
+    x = special_math.ndtri(p)
+    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
 
   def testNdtriDynamicShape(self):
     """Verifies that ndtri computation is correct."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if not special:
         return
 
@@ -286,7 +285,7 @@
   def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
     raw_grid = _make_grid(dtype, grid_spec)
     grid = ops.convert_to_tensor(raw_grid)
-    with self.test_session():
+    with self.cached_session():
       fn = sm.log_ndtr if self._use_log else sm.ndtr
 
       # If there are N points in the grid,
@@ -355,7 +354,7 @@
 class ErfInvTest(test.TestCase):
 
   def testErfInvValues(self):
-    with self.test_session():
+    with self.cached_session():
       if not special:
         return
 
@@ -366,7 +365,7 @@
       self.assertAllClose(expected_x, x.eval(), atol=0.)
 
   def testErfInvIntegerInput(self):
-    with self.test_session():
+    with self.cached_session():
 
       with self.assertRaises(TypeError):
         x = np.array([1, 2, 3]).astype(np.int32)
@@ -397,7 +396,7 @@
     self.assertAllEqual(np.ones_like(x, dtype=np.bool), x)
 
   def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
-    with self.test_session():
+    with self.cached_session():
       grid = _make_grid(dtype, grid_spec)
       actual = sm.log_cdf_laplace(grid).eval()
 
@@ -439,7 +438,7 @@
         ErrorSpec(rtol=0.05, atol=0))
 
   def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
       # fine, but test to -200 anyways.
       grid = _make_grid(
@@ -458,7 +457,7 @@
       self.assertFalse(np.any(grad_ == 0))
 
   def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
       # fine, but test to -200 anyways.
       grid = _make_grid(
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
index 0559054..b34b538 100644
--- a/tensorflow/python/kernel_tests/distributions/student_t_test.py
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -50,100 +50,96 @@
 class StudentTTest(test.TestCase):
 
   def testStudentPDFAndLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([3.] * batch_size)
-      mu = constant_op.constant([7.] * batch_size)
-      sigma = constant_op.constant([8.] * batch_size)
-      df_v = 3.
-      mu_v = 7.
-      sigma_v = 8.
-      t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
-      student = student_t.StudentT(df, loc=mu, scale=-sigma)
+    batch_size = 6
+    df = constant_op.constant([3.] * batch_size)
+    mu = constant_op.constant([7.] * batch_size)
+    sigma = constant_op.constant([8.] * batch_size)
+    df_v = 3.
+    mu_v = 7.
+    sigma_v = 8.
+    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+    student = student_t.StudentT(df, loc=mu, scale=-sigma)
 
-      log_pdf = student.log_prob(t)
-      self.assertEquals(log_pdf.get_shape(), (6,))
-      log_pdf_values = self.evaluate(log_pdf)
-      pdf = student.prob(t)
-      self.assertEquals(pdf.get_shape(), (6,))
-      pdf_values = self.evaluate(pdf)
+    log_pdf = student.log_prob(t)
+    self.assertEquals(log_pdf.get_shape(), (6,))
+    log_pdf_values = self.evaluate(log_pdf)
+    pdf = student.prob(t)
+    self.assertEquals(pdf.get_shape(), (6,))
+    pdf_values = self.evaluate(pdf)
 
-      if not stats:
-        return
+    if not stats:
+      return
 
-      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
-      self.assertAllClose(expected_pdf, pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+    self.assertAllClose(expected_pdf, pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   def testStudentLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([[1.5, 7.2]] * batch_size)
-      mu = constant_op.constant([[3., -3.]] * batch_size)
-      sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
-                                   batch_size)
-      df_v = np.array([1.5, 7.2])
-      mu_v = np.array([3., -3.])
-      sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
-      t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
-      student = student_t.StudentT(df, loc=mu, scale=sigma)
-      log_pdf = student.log_prob(t)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = student.prob(t)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
+    batch_size = 6
+    df = constant_op.constant([[1.5, 7.2]] * batch_size)
+    mu = constant_op.constant([[3., -3.]] * batch_size)
+    sigma = constant_op.constant(
+        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+    df_v = np.array([1.5, 7.2])
+    mu_v = np.array([3., -3.])
+    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
+    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
+    student = student_t.StudentT(df, loc=mu, scale=sigma)
+    log_pdf = student.log_prob(t)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = student.prob(t)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
-      self.assertAllClose(expected_pdf, pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    if not stats:
+      return
+    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+    self.assertAllClose(expected_pdf, pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   def testStudentCDFAndLogCDF(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([3.] * batch_size)
-      mu = constant_op.constant([7.] * batch_size)
-      sigma = constant_op.constant([-8.] * batch_size)
-      df_v = 3.
-      mu_v = 7.
-      sigma_v = 8.
-      t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
-      student = student_t.StudentT(df, loc=mu, scale=sigma)
+    batch_size = 6
+    df = constant_op.constant([3.] * batch_size)
+    mu = constant_op.constant([7.] * batch_size)
+    sigma = constant_op.constant([-8.] * batch_size)
+    df_v = 3.
+    mu_v = 7.
+    sigma_v = 8.
+    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+    student = student_t.StudentT(df, loc=mu, scale=sigma)
 
-      log_cdf = student.log_cdf(t)
-      self.assertEquals(log_cdf.get_shape(), (6,))
-      log_cdf_values = self.evaluate(log_cdf)
-      cdf = student.cdf(t)
-      self.assertEquals(cdf.get_shape(), (6,))
-      cdf_values = self.evaluate(cdf)
+    log_cdf = student.log_cdf(t)
+    self.assertEquals(log_cdf.get_shape(), (6,))
+    log_cdf_values = self.evaluate(log_cdf)
+    cdf = student.cdf(t)
+    self.assertEquals(cdf.get_shape(), (6,))
+    cdf_values = self.evaluate(cdf)
 
-      if not stats:
-        return
-      expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(
-          np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(
-          np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
+    if not stats:
+      return
+    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(
+        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(
+        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
 
   def testStudentEntropy(self):
     df_v = np.array([[2., 3., 7.]])  # 1x3
     mu_v = np.array([[1., -1, 0]])  # 1x3
     sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
-    with self.test_session():
-      student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
-      ent = student.entropy()
-      ent_values = self.evaluate(ent)
+    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+    ent = student.entropy()
+    ent_values = self.evaluate(ent)
 
     # Help scipy broadcast to 3x3
     ones = np.array([[1, 1, 1]])
@@ -160,90 +156,81 @@
     self.assertAllClose(expected_entropy, ent_values)
 
   def testStudentSample(self):
-    with self.test_session():
-      df = constant_op.constant(4.)
-      mu = constant_op.constant(3.)
-      sigma = constant_op.constant(-math.sqrt(10.))
-      df_v = 4.
-      mu_v = 3.
-      sigma_v = np.sqrt(10.)
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      n_val = 200000
-      self.assertEqual(sample_values.shape, (n_val,))
-      self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values.var(),
-          sigma_v**2 * df_v / (df_v - 2),
-          rtol=0.1,
-          atol=0)
-      self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+    df = constant_op.constant(4.)
+    mu = constant_op.constant(3.)
+    sigma = constant_op.constant(-math.sqrt(10.))
+    df_v = 4.
+    mu_v = 3.
+    sigma_v = np.sqrt(10.)
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    n_val = 200000
+    self.assertEqual(sample_values.shape, (n_val,))
+    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
+    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
 
   # Test that sampling with the same seed twice gives the same results.
   def testStudentSampleMultipleTimes(self):
-    with self.test_session():
-      df = constant_op.constant(4.)
-      mu = constant_op.constant(3.)
-      sigma = constant_op.constant(math.sqrt(10.))
-      n = constant_op.constant(100)
+    df = constant_op.constant(4.)
+    mu = constant_op.constant(3.)
+    sigma = constant_op.constant(math.sqrt(10.))
+    n = constant_op.constant(100)
 
-      random_seed.set_random_seed(654321)
-      student = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, name="student_t1")
-      samples1 = self.evaluate(student.sample(n, seed=123456))
+    random_seed.set_random_seed(654321)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
+    samples1 = self.evaluate(student.sample(n, seed=123456))
 
-      random_seed.set_random_seed(654321)
-      student2 = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, name="student_t2")
-      samples2 = self.evaluate(student2.sample(n, seed=123456))
+    random_seed.set_random_seed(654321)
+    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
+    samples2 = self.evaluate(student2.sample(n, seed=123456))
 
-      self.assertAllClose(samples1, samples2)
+    self.assertAllClose(samples1, samples2)
 
   def testStudentSampleSmallDfNoNan(self):
-    with self.test_session():
-      df_v = [1e-1, 1e-5, 1e-10, 1e-20]
-      df = constant_op.constant(df_v)
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=1., scale=1.)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      n_val = 200000
-      self.assertEqual(sample_values.shape, (n_val, 4))
-      self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
+    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
+    df = constant_op.constant(df_v)
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=1., scale=1.)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    n_val = 200000
+    self.assertEqual(sample_values.shape, (n_val, 4))
+    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
 
   def testStudentSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 7
-      df = constant_op.constant([[5., 7.]] * batch_size)
-      mu = constant_op.constant([[3., -3.]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
-                                   batch_size)
-      df_v = [5., 7.]
-      mu_v = [3., -3.]
-      sigma_v = [np.sqrt(10.), np.sqrt(15.)]
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
-      self.assertAllClose(
-          sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values[:, 0, 0].var(),
-          sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
-          rtol=0.2,
-          atol=0)
-      self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
-      self.assertAllClose(
-          sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values[:, 0, 1].var(),
-          sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
-          rtol=0.2,
-          atol=0)
-      self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
+    batch_size = 7
+    df = constant_op.constant([[5., 7.]] * batch_size)
+    mu = constant_op.constant([[3., -3.]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+    df_v = [5., 7.]
+    mu_v = [3., -3.]
+    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
+    self.assertAllClose(
+        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values[:, 0, 0].var(),
+        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+        rtol=0.2,
+        atol=0)
+    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+    self.assertAllClose(
+        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values[:, 0, 1].var(),
+        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+        rtol=0.2,
+        atol=0)
+    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
 
   def _checkKLApprox(self, df, mu, sigma, samples):
     n = samples.size
@@ -325,114 +312,102 @@
     _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
 
   def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
-    with self.test_session():
-      mu = [1., 3.3, 4.4]
-      student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
-      mean = self.evaluate(student.mean())
-      self.assertAllClose([1., 3.3, 4.4], mean)
+    mu = [1., 3.3, 4.4]
+    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+    mean = self.evaluate(student.mean())
+    self.assertAllClose([1., 3.3, 4.4], mean)
 
   def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
-    with self.test_session():
-      mu = [1., 3.3, 4.4]
-      student = student_t.StudentT(
-          df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
-          allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.mean())
+    mu = [1., 3.3, 4.4]
+    student = student_t.StudentT(
+        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.mean())
 
   def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
-    with self.test_session():
-      mu = [-2, 0., 1., 3.3, 4.4]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(
-          df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
-          allow_nan_stats=True)
-      mean = self.evaluate(student.mean())
-      self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
+    mu = [-2, 0., 1., 3.3, 4.4]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(
+        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
+    mean = self.evaluate(student.mean())
+    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
 
   def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
-    with self.test_session():
-      # df = 0.5 ==> undefined mean ==> undefined variance.
-      # df = 1.5 ==> infinite variance.
-      df = [0.5, 1.5, 3., 5., 7.]
-      mu = [-2, 0., 1., 3.3, 4.4]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, allow_nan_stats=True)
-      var = self.evaluate(student.variance())
-      ## scipy uses inf for variance when the mean is undefined.  When mean is
-      # undefined we say variance is undefined as well.  So test the first
-      # member of var, making sure it is NaN, then replace with inf and compare
-      # to scipy.
-      self.assertTrue(np.isnan(var[0]))
-      var[0] = np.inf
+    # df = 0.5 ==> undefined mean ==> undefined variance.
+    # df = 1.5 ==> infinite variance.
+    df = [0.5, 1.5, 3., 5., 7.]
+    mu = [-2, 0., 1., 3.3, 4.4]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(
+        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
+    var = self.evaluate(student.variance())
+    ## scipy uses inf for variance when the mean is undefined.  When mean is
+    # undefined we say variance is undefined as well.  So test the first
+    # member of var, making sure it is NaN, then replace with inf and compare
+    # to scipy.
+    self.assertTrue(np.isnan(var[0]))
+    var[0] = np.inf
 
-      if not stats:
-        return
-      expected_var = [
-          stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_var, var)
+    if not stats:
+      return
+    expected_var = [
+        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_var, var)
 
   def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
       self):
-    with self.test_session():
-      # df = 1.5 ==> infinite variance.
-      df = [1.5, 3., 5., 7.]
-      mu = [0., 1., 3.3, 4.4]
-      sigma = [4., 3., 2., 1.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      var = self.evaluate(student.variance())
+    # df = 1.5 ==> infinite variance.
+    df = [1.5, 3., 5., 7.]
+    mu = [0., 1., 3.3, 4.4]
+    sigma = [4., 3., 2., 1.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    var = self.evaluate(student.variance())
 
-      if not stats:
-        return
-      expected_var = [
-          stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_var, var)
+    if not stats:
+      return
+    expected_var = [
+        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_var, var)
 
   def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
-    with self.test_session():
-      # df <= 1 ==> variance not defined
-      student = student_t.StudentT(
-          df=1., loc=0., scale=1., allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.variance())
+    # df <= 1 ==> variance not defined
+    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.variance())
 
-    with self.test_session():
-      # df <= 1 ==> variance not defined
-      student = student_t.StudentT(
-          df=0.5, loc=0., scale=1., allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.variance())
+    # df <= 1 ==> variance not defined
+    student = student_t.StudentT(
+        df=0.5, loc=0., scale=1., allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.variance())
 
   def testStd(self):
-    with self.test_session():
-      # Defined for all batch members.
-      df = [3.5, 5., 3., 5., 7.]
-      mu = [-2.2]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      # Test broadcast of mu across shape of df/sigma
-      stddev = self.evaluate(student.stddev())
-      mu *= len(df)
+    # Defined for all batch members.
+    df = [3.5, 5., 3., 5., 7.]
+    mu = [-2.2]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    # Test broadcast of mu across shape of df/sigma
+    stddev = self.evaluate(student.stddev())
+    mu *= len(df)
 
-      if not stats:
-        return
-      expected_stddev = [
-          stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_stddev, stddev)
+    if not stats:
+      return
+    expected_stddev = [
+        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_stddev, stddev)
 
   def testMode(self):
-    with self.test_session():
-      df = [0.5, 1., 3]
-      mu = [-1, 0., 1]
-      sigma = [5., 4., 3.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      # Test broadcast of mu across shape of df/sigma
-      mode = self.evaluate(student.mode())
-      self.assertAllClose([-1., 0, 1], mode)
+    df = [0.5, 1., 3]
+    mu = [-1, 0., 1]
+    sigma = [5., 4., 3.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    # Test broadcast of mu across shape of df/sigma
+    mode = self.evaluate(student.mode())
+    self.assertAllClose([-1., 0, 1], mode)
 
   def testPdfOfSample(self):
     student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
@@ -510,25 +485,23 @@
     self.assertNear(1., total, err=err)
 
   def testNegativeDofFails(self):
-    with self.test_session():
-      with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
-        student = student_t.StudentT(
-            df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
-        self.evaluate(student.mean())
+    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
+      student = student_t.StudentT(
+          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
+      self.evaluate(student.mean())
 
   def testStudentTWithAbsDfSoftplusScale(self):
-    with self.test_session():
-      df = constant_op.constant([-3.2, -4.6])
-      mu = constant_op.constant([-4.2, 3.4])
-      sigma = constant_op.constant([-6.4, -8.8])
-      student = student_t.StudentTWithAbsDfSoftplusScale(
-          df=df, loc=mu, scale=sigma)
-      self.assertAllClose(
-          math_ops.floor(self.evaluate(math_ops.abs(df))),
-          self.evaluate(student.df))
-      self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
+    df = constant_op.constant([-3.2, -4.6])
+    mu = constant_op.constant([-4.2, 3.4])
+    sigma = constant_op.constant([-6.4, -8.8])
+    student = student_t.StudentTWithAbsDfSoftplusScale(
+        df=df, loc=mu, scale=sigma)
+    self.assertAllClose(
+        math_ops.floor(self.evaluate(math_ops.abs(df))),
+        self.evaluate(student.df))
+    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index bc9c267..9cdcd36 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -50,255 +50,239 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformRange(self):
-    with self.test_session():
-      a = 3.0
-      b = 10.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      self.assertAllClose(a, self.evaluate(uniform.low))
-      self.assertAllClose(b, self.evaluate(uniform.high))
-      self.assertAllClose(b - a, self.evaluate(uniform.range()))
+    a = 3.0
+    b = 10.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    self.assertAllClose(a, self.evaluate(uniform.low))
+    self.assertAllClose(b, self.evaluate(uniform.high))
+    self.assertAllClose(b - a, self.evaluate(uniform.range()))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformPDF(self):
-    with self.test_session():
-      a = constant_op.constant([-3.0] * 5 + [15.0])
-      b = constant_op.constant([11.0] * 5 + [20.0])
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([-3.0] * 5 + [15.0])
+    b = constant_op.constant([11.0] * 5 + [20.0])
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      a_v = -3.0
-      b_v = 11.0
-      x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
+    a_v = -3.0
+    b_v = 11.0
+    x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
 
-      def _expected_pdf():
-        pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
-        pdf[x > b_v] = 0.0
-        pdf[x < a_v] = 0.0
-        pdf[5] = 1.0 / (20.0 - 15.0)
-        return pdf
+    def _expected_pdf():
+      pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
+      pdf[x > b_v] = 0.0
+      pdf[x < a_v] = 0.0
+      pdf[5] = 1.0 / (20.0 - 15.0)
+      return pdf
 
-      expected_pdf = _expected_pdf()
+    expected_pdf = _expected_pdf()
 
-      pdf = uniform.prob(x)
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(x)
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
-      log_pdf = uniform.log_prob(x)
-      self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
+    log_pdf = uniform.log_prob(x)
+    self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformShape(self):
-    with self.test_session():
-      a = constant_op.constant([-3.0] * 5)
-      b = constant_op.constant(11.0)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([-3.0] * 5)
+    b = constant_op.constant(11.0)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
-      self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
-      self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
+    self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
+    self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformPDFWithScalarEndpoint(self):
-    with self.test_session():
-      a = constant_op.constant([0.0, 5.0])
-      b = constant_op.constant(10.0)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([0.0, 5.0])
+    b = constant_op.constant(10.0)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      x = np.array([0.0, 8.0], dtype=np.float32)
-      expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
+    x = np.array([0.0, 8.0], dtype=np.float32)
+    expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
 
-      pdf = uniform.prob(x)
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(x)
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformCDF(self):
-    with self.test_session():
-      batch_size = 6
-      a = constant_op.constant([1.0] * batch_size)
-      b = constant_op.constant([11.0] * batch_size)
-      a_v = 1.0
-      b_v = 11.0
-      x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
+    batch_size = 6
+    a = constant_op.constant([1.0] * batch_size)
+    b = constant_op.constant([11.0] * batch_size)
+    a_v = 1.0
+    b_v = 11.0
+    x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
 
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      def _expected_cdf():
-        cdf = (x - a_v) / (b_v - a_v)
-        cdf[x >= b_v] = 1
-        cdf[x < a_v] = 0
-        return cdf
+    def _expected_cdf():
+      cdf = (x - a_v) / (b_v - a_v)
+      cdf[x >= b_v] = 1
+      cdf[x < a_v] = 0
+      return cdf
 
-      cdf = uniform.cdf(x)
-      self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
+    cdf = uniform.cdf(x)
+    self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
 
-      log_cdf = uniform.log_cdf(x)
-      self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
+    log_cdf = uniform.log_cdf(x)
+    self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformEntropy(self):
-    with self.test_session():
-      a_v = np.array([1.0, 1.0, 1.0])
-      b_v = np.array([[1.5, 2.0, 3.0]])
-      uniform = uniform_lib.Uniform(low=a_v, high=b_v)
+    a_v = np.array([1.0, 1.0, 1.0])
+    b_v = np.array([[1.5, 2.0, 3.0]])
+    uniform = uniform_lib.Uniform(low=a_v, high=b_v)
 
-      expected_entropy = np.log(b_v - a_v)
-      self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
+    expected_entropy = np.log(b_v - a_v)
+    self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformAssertMaxGtMin(self):
-    with self.test_session():
-      a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
-      b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+    a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
+    b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
 
-      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
-                                               "x < y"):
-        uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
-        self.evaluate(uniform.low)
+    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                             "x < y"):
+      uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
+      self.evaluate(uniform.low)
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSample(self):
-    with self.test_session():
-      a = constant_op.constant([3.0, 4.0])
-      b = constant_op.constant(13.0)
-      a1_v = 3.0
-      a2_v = 4.0
-      b_v = 13.0
-      n = constant_op.constant(100000)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([3.0, 4.0])
+    b = constant_op.constant(13.0)
+    a1_v = 3.0
+    a2_v = 4.0
+    b_v = 13.0
+    n = constant_op.constant(100000)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      samples = uniform.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertAllClose(
-          sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
-      self.assertAllClose(
-          sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
-      self.assertFalse(
-          np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
-      self.assertFalse(
-          np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
+    samples = uniform.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertAllClose(
+        sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
+    self.assertAllClose(
+        sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
+    self.assertFalse(
+        np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
+    self.assertFalse(
+        np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
 
   @test_util.run_in_graph_and_eager_modes
   def _testUniformSampleMultiDimensional(self):
     # DISABLED: Please enable this test once b/issues/30149644 is resolved.
-    with self.test_session():
-      batch_size = 2
-      a_v = [3.0, 22.0]
-      b_v = [13.0, 35.0]
-      a = constant_op.constant([a_v] * batch_size)
-      b = constant_op.constant([b_v] * batch_size)
+    batch_size = 2
+    a_v = [3.0, 22.0]
+    b_v = [13.0, 35.0]
+    a = constant_op.constant([a_v] * batch_size)
+    b = constant_op.constant([b_v] * batch_size)
 
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      n_v = 100000
-      n = constant_op.constant(n_v)
-      samples = uniform.sample(n)
-      self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+    n_v = 100000
+    n = constant_op.constant(n_v)
+    samples = uniform.sample(n)
+    self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
 
-      sample_values = self.evaluate(samples)
+    sample_values = self.evaluate(samples)
 
-      self.assertFalse(
-          np.any(sample_values[:, 0, 0] < a_v[0]) or
-          np.any(sample_values[:, 0, 0] >= b_v[0]))
-      self.assertFalse(
-          np.any(sample_values[:, 0, 1] < a_v[1]) or
-          np.any(sample_values[:, 0, 1] >= b_v[1]))
+    self.assertFalse(
+        np.any(sample_values[:, 0, 0] < a_v[0]) or
+        np.any(sample_values[:, 0, 0] >= b_v[0]))
+    self.assertFalse(
+        np.any(sample_values[:, 0, 1] < a_v[1]) or
+        np.any(sample_values[:, 0, 1] >= b_v[1]))
 
-      self.assertAllClose(
-          sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
-      self.assertAllClose(
-          sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
+    self.assertAllClose(
+        sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
+    self.assertAllClose(
+        sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformMean(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformVariance(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformStd(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformNans(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 100.0]
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = 10.0
+    b = [11.0, 100.0]
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      no_nans = constant_op.constant(1.0)
-      nans = constant_op.constant(0.0) / constant_op.constant(0.0)
-      self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
-      with_nans = array_ops.stack([no_nans, nans])
+    no_nans = constant_op.constant(1.0)
+    nans = constant_op.constant(0.0) / constant_op.constant(0.0)
+    self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
+    with_nans = array_ops.stack([no_nans, nans])
 
-      pdf = uniform.prob(with_nans)
+    pdf = uniform.prob(with_nans)
 
-      is_nan = self.evaluate(math_ops.is_nan(pdf))
-      self.assertFalse(is_nan[0])
-      self.assertTrue(is_nan[1])
+    is_nan = self.evaluate(math_ops.is_nan(pdf))
+    self.assertFalse(is_nan[0])
+    self.assertTrue(is_nan[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSamplePdf(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 100.0]
-      uniform = uniform_lib.Uniform(a, b)
-      self.assertTrue(
-          self.evaluate(
-              math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
+    a = 10.0
+    b = [11.0, 100.0]
+    uniform = uniform_lib.Uniform(a, b)
+    self.assertTrue(
+        self.evaluate(
+            math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformBroadcasting(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 20.0]
-      uniform = uniform_lib.Uniform(a, b)
+    a = 10.0
+    b = [11.0, 20.0]
+    uniform = uniform_lib.Uniform(a, b)
 
-      pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
-      expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
+    expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSampleWithShape(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 20.0]
-      uniform = uniform_lib.Uniform(a, b)
+    a = 10.0
+    b = [11.0, 20.0]
+    uniform = uniform_lib.Uniform(a, b)
 
-      pdf = uniform.prob(uniform.sample((2, 3)))
-      # pylint: disable=bad-continuation
-      expected_pdf = [
-          [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
-          [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
-      ]
-      # pylint: enable=bad-continuation
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(uniform.sample((2, 3)))
+    # pylint: disable=bad-continuation
+    expected_pdf = [
+        [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+        [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+    ]
+    # pylint: enable=bad-continuation
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
-      pdf = uniform.prob(uniform.sample())
-      expected_pdf = [1.0, 0.1]
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(uniform.sample())
+    expected_pdf = [1.0, 0.1]
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   def testFullyReparameterized(self):
     a = constant_op.constant(0.1)
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 61faa84..27d652c 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -69,7 +69,7 @@
     w = array_ops.placeholder(dtypes.float32)
     feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20],
                  z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]}
-    with self.test_session():
+    with self.cached_session():
       with ops.control_dependencies([du.assert_integer_form(x)]):
         array_ops.identity(x).eval(feed_dict=feed_dict)
 
@@ -122,58 +122,52 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testImproperArguments(self):
-    with self.test_session():
-      with self.assertRaises(ValueError):
-        du.get_logits_and_probs(logits=None, probs=None)
+    with self.assertRaises(ValueError):
+      du.get_logits_and_probs(logits=None, probs=None)
 
-      with self.assertRaises(ValueError):
-        du.get_logits_and_probs(logits=[0.1], probs=[0.1])
+    with self.assertRaises(ValueError):
+      du.get_logits_and_probs(logits=[0.1], probs=[0.1])
 
   @test_util.run_in_graph_and_eager_modes
   def testLogits(self):
     p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
     logits = _logit(p)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          logits=logits, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        logits=logits, validate_args=True)
 
-      self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
-      self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
+    self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
+    self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
 
   @test_util.run_in_graph_and_eager_modes
   def testLogitsMultidimensional(self):
     p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
     logits = np.log(p)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          logits=logits, multidimensional=True, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        logits=logits, multidimensional=True, validate_args=True)
 
-      self.assertAllClose(self.evaluate(new_p), p)
-      self.assertAllClose(self.evaluate(new_logits), logits)
+    self.assertAllClose(self.evaluate(new_p), p)
+    self.assertAllClose(self.evaluate(new_logits), logits)
 
   @test_util.run_in_graph_and_eager_modes
   def testProbability(self):
     p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          probs=p, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True)
 
-      self.assertAllClose(_logit(p), self.evaluate(new_logits))
-      self.assertAllClose(p, self.evaluate(new_p))
+    self.assertAllClose(_logit(p), self.evaluate(new_logits))
+    self.assertAllClose(p, self.evaluate(new_p))
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityMultidimensional(self):
     p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          probs=p, multidimensional=True, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        probs=p, multidimensional=True, validate_args=True)
 
-      self.assertAllClose(np.log(p), self.evaluate(new_logits))
-      self.assertAllClose(p, self.evaluate(new_p))
+    self.assertAllClose(np.log(p), self.evaluate(new_logits))
+    self.assertAllClose(p, self.evaluate(new_p))
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityValidateArgs(self):
@@ -183,28 +177,22 @@
     # Component greater than 1.
     p3 = [2, 0.2, 0.5, 0.3, .2]
 
-    with self.test_session():
-      _, prob = du.get_logits_and_probs(
-          probs=p, validate_args=True)
+    _, prob = du.get_logits_and_probs(probs=p, validate_args=True)
+    self.evaluate(prob)
+
+    with self.assertRaisesOpError("Condition x >= 0"):
+      _, prob = du.get_logits_and_probs(probs=p2, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("Condition x >= 0"):
-        _, prob = du.get_logits_and_probs(
-            probs=p2, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(probs=p2, validate_args=False)
+    self.evaluate(prob)
 
-      _, prob = du.get_logits_and_probs(
-          probs=p2, validate_args=False)
+    with self.assertRaisesOpError("probs has components greater than 1"):
+      _, prob = du.get_logits_and_probs(probs=p3, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("probs has components greater than 1"):
-        _, prob = du.get_logits_and_probs(
-            probs=p3, validate_args=True)
-        self.evaluate(prob)
-
-      _, prob = du.get_logits_and_probs(
-          probs=p3, validate_args=False)
-      self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(probs=p3, validate_args=False)
+    self.evaluate(prob)
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityValidateArgsMultidimensional(self):
@@ -216,41 +204,39 @@
     # Does not sum to 1.
     p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)
 
-    with self.test_session():
+    _, prob = du.get_logits_and_probs(probs=p, multidimensional=True)
+    self.evaluate(prob)
+
+    with self.assertRaisesOpError("Condition x >= 0"):
       _, prob = du.get_logits_and_probs(
-          probs=p, multidimensional=True)
+          probs=p2, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("Condition x >= 0"):
-        _, prob = du.get_logits_and_probs(
-            probs=p2, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p2, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
+    with self.assertRaisesOpError(
+        "(probs has components greater than 1|probs does not sum to 1)"):
       _, prob = du.get_logits_and_probs(
-          probs=p2, multidimensional=True, validate_args=False)
+          probs=p3, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError(
-          "(probs has components greater than 1|probs does not sum to 1)"):
-        _, prob = du.get_logits_and_probs(
-            probs=p3, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p3, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
+    with self.assertRaisesOpError("probs does not sum to 1"):
       _, prob = du.get_logits_and_probs(
-          probs=p3, multidimensional=True, validate_args=False)
+          probs=p4, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("probs does not sum to 1"):
-        _, prob = du.get_logits_and_probs(
-            probs=p4, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
-
-      _, prob = du.get_logits_and_probs(
-          probs=p4, multidimensional=True, validate_args=False)
-      self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p4, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
   def testProbsMultidimShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         p = array_ops.ones([int(2**11+1)], dtype=np.float16)
         du.get_logits_and_probs(
@@ -264,7 +250,7 @@
         prob.eval(feed_dict={p: np.ones([int(2**11+1)])})
 
   def testLogitsMultidimShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         l = array_ops.ones([int(2**11+1)], dtype=np.float16)
         du.get_logits_and_probs(
@@ -281,7 +267,7 @@
 class EmbedCheckCategoricalEventShapeTest(test.TestCase):
 
   def testTooSmall(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         param = array_ops.ones([1], dtype=np.float16)
         checked_param = du.embed_check_categorical_event_shape(
@@ -295,7 +281,7 @@
         checked_param.eval(feed_dict={param: np.ones([1])})
 
   def testTooLarge(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16)
         checked_param = du.embed_check_categorical_event_shape(
@@ -310,18 +296,17 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testUnsupportedDtype(self):
-    with self.test_session():
-      param = ops.convert_to_tensor(
-          np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
-          dtype=dtypes.qint16)
-      with self.assertRaises(TypeError):
-        du.embed_check_categorical_event_shape(param)
+    param = ops.convert_to_tensor(
+        np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+        dtype=dtypes.qint16)
+    with self.assertRaises(TypeError):
+      du.embed_check_categorical_event_shape(param)
 
 
 class EmbedCheckIntegerCastingClosedTest(test.TestCase):
 
   def testCorrectlyAssertsNonnegative(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements must be non-negative"):
         x = array_ops.placeholder(dtype=dtypes.float16)
         x_checked = du.embed_check_integer_casting_closed(
@@ -329,7 +314,7 @@
         x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)})
 
   def testCorrectlyAssersIntegerForm(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements must be int16-equivalent."):
         x = array_ops.placeholder(dtype=dtypes.float16)
         x_checked = du.embed_check_integer_casting_closed(
@@ -337,7 +322,7 @@
         x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)})
 
   def testCorrectlyAssertsLargestPossibleInteger(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements cannot exceed 32767."):
         x = array_ops.placeholder(dtype=dtypes.int32)
         x_checked = du.embed_check_integer_casting_closed(
@@ -345,7 +330,7 @@
         x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
 
   def testCorrectlyAssertsSmallestPossibleInteger(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements cannot be smaller than 0."):
         x = array_ops.placeholder(dtype=dtypes.int32)
         x_checked = du.embed_check_integer_casting_closed(
@@ -365,29 +350,27 @@
 
     log_combs = np.log(special.binom(n, k))
 
-    with self.test_session():
-      n = np.array(n, dtype=np.float32)
-      counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
-      log_binom = du.log_combinations(n, counts)
-      self.assertEqual([4], log_binom.get_shape())
-      self.assertAllClose(log_combs, self.evaluate(log_binom))
+    n = np.array(n, dtype=np.float32)
+    counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
+    log_binom = du.log_combinations(n, counts)
+    self.assertEqual([4], log_binom.get_shape())
+    self.assertAllClose(log_combs, self.evaluate(log_binom))
 
   def testLogCombinationsShape(self):
     # Shape [2, 2]
     n = [[2, 5], [12, 15]]
 
-    with self.test_session():
-      n = np.array(n, dtype=np.float32)
-      # Shape [2, 2, 4]
-      counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
-      log_binom = du.log_combinations(n, counts)
-      self.assertEqual([2, 2], log_binom.get_shape())
+    n = np.array(n, dtype=np.float32)
+    # Shape [2, 2, 4]
+    counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
+    log_binom = du.log_combinations(n, counts)
+    self.assertEqual([2, 2], log_binom.get_shape())
 
 
 class DynamicShapeTest(test.TestCase):
 
   def testSameDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       scalar = constant_op.constant(2.0)
       scalar1 = array_ops.placeholder(dtype=dtypes.float32)
 
@@ -497,22 +480,21 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testRollStatic(self):
-    with self.test_session():
-      if context.executing_eagerly():
-        error_message = r"Attempt to convert a value \(None\)"
-      else:
-        error_message = "None values not supported."
-      with self.assertRaisesRegexp(ValueError, error_message):
-        du.rotate_transpose(None, 1)
-      for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
-        for shift in np.arange(-5, 5):
-          y = du.rotate_transpose(x, shift)
-          self.assertAllEqual(
-              self._np_rotate_transpose(x, shift), self.evaluate(y))
-          self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
+    if context.executing_eagerly():
+      error_message = r"Attempt to convert a value \(None\)"
+    else:
+      error_message = "None values not supported."
+    with self.assertRaisesRegexp(ValueError, error_message):
+      du.rotate_transpose(None, 1)
+    for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
+      for shift in np.arange(-5, 5):
+        y = du.rotate_transpose(x, shift)
+        self.assertAllEqual(
+            self._np_rotate_transpose(x, shift), self.evaluate(y))
+        self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
 
   def testRollDynamic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       shift = array_ops.placeholder(dtypes.int32)
       for x_value in (np.ones(
@@ -530,7 +512,7 @@
 class PickVectorTest(test.TestCase):
 
   def testCorrectlyPicksVector(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.arange(10, 12)
       y = np.arange(15, 18)
       self.assertAllEqual(
@@ -568,19 +550,19 @@
   def testDynamicRankEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicRankEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(1, rank.eval(feed_dict={x: []}))
 
   def testDynamicRankEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
 
 
@@ -607,19 +589,19 @@
   def testDynamicShapeEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicShapeEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []}))
 
   def testDynamicShapeEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
 
 
@@ -646,20 +628,20 @@
   def testDynamicValueEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.zeros((2, 3)),
                           value.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicValueEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []}))
 
   def testDynamicValueEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
 
 
@@ -691,7 +673,7 @@
 
   def _run_test(self, x_, use_deferred_shape=False, **kwargs):
     x_ = np.asarray(x_)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       static_shape = None if use_deferred_shape else x_.shape
       x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
       # Add `zeros_like(x)` such that x's value and gradient are identical. We
@@ -761,7 +743,7 @@
 
   def _run_test(self, x_, use_deferred_shape=False, **kwargs):
     x_ = np.asarray(x_)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       static_shape = None if use_deferred_shape else x_.shape
       x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
       zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
@@ -795,7 +777,7 @@
     logx_ = np.array([[0., -1, 1000.],
                       [0, 1, -1000.],
                       [-5, 0, 5]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       expected = math_ops.reduce_logsumexp(logx, axis=-1)
       grad_expected = gradients_impl.gradients(expected, logx)[0]
@@ -818,7 +800,7 @@
                    [1, -2, 1],
                    [1, 0, 1]])
     expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       w = constant_op.constant(w_)
       actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -836,7 +818,7 @@
                    [1, 0, 1]])
     expected, _ = self._reduce_weighted_logsumexp(
         logx_, w_, axis=-1, keep_dims=True)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       w = constant_op.constant(w_)
       actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -848,7 +830,7 @@
   def testDocString(self):
     """This test verifies the correctness of the docstring examples."""
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([[0., 0, 0],
                                 [0, 0, 0]])
 
@@ -952,7 +934,7 @@
           use_gpu=True)
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -968,7 +950,7 @@
     self.assertLess(err, 1e-4)
 
   def testInverseSoftplusGradientNeverNan(self):
-    with self.test_session():
+    with self.cached_session():
       # Note that this range contains both zero and inf.
       x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
       y = du.softplus_inverse(x)
@@ -977,7 +959,7 @@
       self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
 
   def testInverseSoftplusGradientFinite(self):
-    with self.test_session():
+    with self.cached_session():
       # This range of x is all finite, and so is 1 / x.  So the
       # gradient and its approximations should be finite as well.
       x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
diff --git a/tensorflow/python/kernel_tests/division_future_test.py b/tensorflow/python/kernel_tests/division_future_test.py
index e681b32..e477bdc 100644
--- a/tensorflow/python/kernel_tests/division_future_test.py
+++ b/tensorflow/python/kernel_tests/division_future_test.py
@@ -50,7 +50,7 @@
         self.assertEqual(x, y)
       checks.append(f)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for dtype in dtypes:
         for x in map(dtype, values):
           for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/division_past_test.py b/tensorflow/python/kernel_tests/division_past_test.py
index 9ddd62e..63951b5 100644
--- a/tensorflow/python/kernel_tests/division_past_test.py
+++ b/tensorflow/python/kernel_tests/division_past_test.py
@@ -49,7 +49,7 @@
         self.assertEqual(x, y)
       checks.append(f)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for dtype in dtypes:
         for x in map(dtype, values):
           for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py
index 529d3dd..654267a 100644
--- a/tensorflow/python/kernel_tests/duplicate_op_test.py
+++ b/tensorflow/python/kernel_tests/duplicate_op_test.py
@@ -34,7 +34,7 @@
 
     self.assertEqual(len(duplicate.OP_LIST.op), 0)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(math_ops.add(1, 41).eval(), 42)
 
 
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 5e8937a..9557e30 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -288,7 +288,7 @@
       self.assertAllEqual([], partition_vals[i])
 
   def testErrorIndexOutOfRange(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
                                    [12, 13, 14]])
       indices = constant_op.constant([0, 2, 99, 2, 2])
@@ -298,7 +298,7 @@
         sess.run(partitions)
 
   def testScalarIndexOutOfRange(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       bad = 17
       data = np.zeros(5)
       partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
@@ -306,7 +306,7 @@
         sess.run(partitions)
 
   def testHigherRankIndexOutOfRange(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shape = (2, 3)
       indices = array_ops.placeholder(shape=shape, dtype=np.int32)
       data = np.zeros(shape + (5,))
@@ -334,7 +334,7 @@
     inds += [13]*194 + [14]*194 + [15]*192
     self.assertEqual(len(inds), x.shape[0])
     partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       res = sess.run(partitioned)
     self.assertEqual(res[-1].shape[0], 192)
 
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index c4d4ce7..3a1036e 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -104,6 +104,27 @@
       # Dimension 0 is max(flatten(indices))+1.
       self.assertEqual([8, 2], stitched_t.get_shape().as_list())
 
+  def testZeroSizeTensor(self):
+    with self.test_session(use_gpu=True):
+      indices = [
+          constant_op.constant([0, 4, 7]),
+          constant_op.constant([1, 6]),
+          constant_op.constant([2, 3, 5]),
+          array_ops.zeros([0], dtype=dtypes.int32)
+      ]
+      data = [
+          constant_op.constant([[0, 1], [40, 41], [70, 71]]),
+          constant_op.constant([[10, 11], [60, 61]]),
+          constant_op.constant([[20, 21], [30, 31], [50, 51]]),
+          array_ops.zeros([0, 2], dtype=dtypes.int32)
+      ]
+      stitched_t = self.stitch_op(indices, data)
+      stitched_val = stitched_t.eval()
+      self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
+                           [50, 51], [60, 61], [70, 71]], stitched_val)
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+
   def testHigherRank(self):
     with self.test_session(use_gpu=True) as sess:
       indices = [
@@ -231,7 +252,7 @@
 
   # GPU version unit tests
   def testScalarGPU(self):
-    with self.test_session():
+    with self.cached_session():
       indices = [constant_op.constant(0), constant_op.constant(1)]
       data = [constant_op.constant(40.0), constant_op.constant(60.0)]
       for step in -1, 1:
@@ -242,7 +263,7 @@
         self.assertEqual([2], stitched_t.get_shape().as_list())
 
   def testHigherRankGPU(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       indices = [
           constant_op.constant(6),
           constant_op.constant([4, 1]),
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index dcd435e..40b8548 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -242,7 +242,7 @@
   # vector is going to be empty. The subsequent DivOp fails because of that.
   # TODO(keveman): Disabling the test until the underlying problem is fixed.
   def testSimpleSharded(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 2
       vocab_size = 4
       p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
@@ -258,7 +258,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testMaxNorm(self):
-    with self.test_session():
+    with self.cached_session():
       embeddings = constant_op.constant([[2.0]])
 
       ids = constant_op.constant([0], dtype=dtypes.int32)
@@ -268,7 +268,7 @@
       self.assertAllEqual(embedding.eval(), [[1.0]])
 
   def testMaxNormNontrivial(self):
-    with self.test_session():
+    with self.cached_session():
       embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])
 
       ids = constant_op.constant([0, 1], dtype=dtypes.int32)
@@ -281,7 +281,7 @@
       self.assertAllEqual(embedding.eval(), 2 * normalized.eval())
 
   def testSimpleShardedPartitionedVariable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_shards = 2
       vocab_size = 4
       p, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
@@ -303,7 +303,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testSimpleShardedPartitionedResourceVariable(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_shards = 2
       vocab_size = 4
       p, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable(
@@ -326,7 +326,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedModPartitioningInt32Ids(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -348,7 +348,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedModPartitioningInt64Ids(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -370,7 +370,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedDivPartitioningInt32Ids(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -394,7 +394,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedDivPartitioningInt32IdsPartitionedVariable(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -419,7 +419,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedDivPartitioningInt64Ids(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -443,7 +443,7 @@
     self.assertShapeEqual(np_result, embedding)
 
   def testShardedDivPartitioningUnknownParamShape(self):
-    with self.test_session():
+    with self.cached_session():
       num_shards = 5
       vocab_size = 13
       # Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -475,7 +475,7 @@
     tf_logging.vlog(1, id_vals)
     for ids_shape in [(10,), (2, 5)]:
       for num_shards in [1, 3]:
-        with self.test_session():
+        with self.cached_session():
           ids = constant_op.constant(
               id_vals, shape=ids_shape, dtype=dtypes.int32)
           x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
@@ -494,7 +494,7 @@
     id_vals = list(np.random.randint(vocab_size, size=num_ids))
     tf_logging.vlog(1, id_vals)
     for num_shards in [1, 3]:
-      with self.test_session():
+      with self.cached_session():
         ids = constant_op.constant(id_vals, dtype=dtypes.int32)
         x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
         # This will force a conversion from IndexedSlices to Tensor.
@@ -528,7 +528,7 @@
 
   def testHigherRank(self):
     np.random.seed(8)
-    with self.test_session():
+    with self.cached_session():
       for params_shape in (12,), (6, 3):
         params = np.random.randn(*params_shape)
         for ids_shape in (3, 2), (4, 3):
@@ -548,7 +548,7 @@
 
   def testHigherRankMaxNorm(self):
     np.random.seed(8)
-    with self.test_session():
+    with self.cached_session():
       for params_shape in (12,), (6, 3), (6, 2, 3):
         # Test embedding rank 0, 1, 2.
         # Note: the first dimension must be a common multiple of procs below.
@@ -581,7 +581,7 @@
     # It always applies max_norm.
     np.random.seed(8)
     l2_norm = 2.
-    with self.test_session():
+    with self.cached_session():
       # Param values are in [l2_norm, l2_norm+1) so it will always clip.
       params = np.random.rand(6, 3) + l2_norm
       params_norm = l2_norm * params / np.sqrt(
@@ -667,7 +667,7 @@
         [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
         [True, False]):
 
-      with self.test_session():
+      with self.cached_session():
         p, params, feed_dict = _EmbeddingParams(
             num_shards, vocab_size, shape=param_shape, dtype=dtype)
         embedding_sum = embedding_ops.embedding_lookup_sparse(
@@ -716,7 +716,7 @@
     for num_shards, combiner, dtype, ignore_weights in itertools.product(
         [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
                                            dtypes.float64], [True, False]):
-      with self.test_session():
+      with self.cached_session():
         x, params, _ = _EmbeddingParams(
             num_shards, vocab_size, shape=param_shape, dtype=dtype)
 
@@ -734,7 +734,7 @@
       self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
       sp_ids = sparse_tensor.SparseTensor(
           constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
@@ -819,7 +819,7 @@
     return sparse_ids, sparse_weights
 
   def test_safe_embedding_lookup_sparse_return_zero_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -832,7 +832,7 @@
            3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
 
   def test_safe_embedding_lookup_sparse_return_special_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -846,7 +846,7 @@
            embedding_weights[0][2], embedding_weights[0][3]])
 
   def test_safe_embedding_lookup_sparse_no_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, _ = self._ids_and_weights_2d()
 
@@ -860,7 +860,7 @@
                embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
 
   def test_safe_embedding_lookup_sparse_partitioned(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, _ = self._ids_and_weights_2d()
 
@@ -874,7 +874,7 @@
                            (embedding_weights[0] + embedding_weights[1]) / 2.0])
 
   def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, sparse_weights = self._ids_and_weights_2d()
 
@@ -889,7 +889,7 @@
                         embedding_weights, sparse_ids, sparse_weights)
 
   def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -902,7 +902,7 @@
       ], [embedding_weights[0][2], [0] * 4, [0] * 4]])
 
   def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -918,7 +918,7 @@
             ]])
 
   def test_safe_embedding_lookup_sparse_3d_no_weights(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights()
       sparse_ids, _ = self._ids_and_weights_3d()
 
@@ -934,7 +934,7 @@
           ]])
 
   def test_safe_embedding_lookup_sparse_3d_partitioned(self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, _ = self._ids_and_weights_3d()
 
@@ -951,7 +951,7 @@
 
   def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
       self):
-    with self.test_session():
+    with self.cached_session():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_ids, sparse_weights = self._ids_and_weights_3d()
 
@@ -1035,7 +1035,7 @@
 
   # We expect that the values are merged in order.
   def testStitchOrder(self):
-    with self.test_session():
+    with self.cached_session():
       indices = []
       np_values = []
       values = []
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index e1f5a6b..7d9d4e5 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -83,7 +83,7 @@
     random_seed = 42
     random_seed_lib.set_random_seed(random_seed)
 
-    with self.test_session():
+    with self.cached_session():
       for test_case in self._TEST_CASES:
         np.random.seed(random_seed)
         in_shape = test_case['in_shape']
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index 629aced..f117934 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -496,7 +496,7 @@
             "Input dimension .* must have length of at least 6 but got: 5"):
           x = np.zeros((5,) * rank).astype(np.float32)
           fft_length = [6] * rank
-          with self.test_session():
+          with self.cached_session():
             rfft_fn(x, fft_length).eval()
 
         with self.assertRaisesWithPredicateMatch(
@@ -504,7 +504,7 @@
             "Input dimension .* must have length of at least .* but got: 3"):
           x = np.zeros((3,) * rank).astype(np.complex64)
           fft_length = [6] * rank
-          with self.test_session():
+          with self.cached_session():
             irfft_fn(x, fft_length).eval()
 
   def testGrad_Simple(self):
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 9e7b528..a5f8f64 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -99,19 +99,19 @@
       """, q.queue_ref.op.node_def)
 
   def testEnqueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       enqueue_op.run()
 
   def testEnqueueHalf(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
       enqueue_op = q.enqueue((10.0,))
       enqueue_op.run()
 
   def testEnqueueWithShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
       enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
       enqueue_correct_op.run()
@@ -120,7 +120,7 @@
       self.assertEqual(1, q.size().eval())
 
   def testEnqueueManyWithShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(
           10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
       q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
@@ -143,7 +143,7 @@
     self.assertAllEqual(self.evaluate(q.dequeue()), 1)
 
   def testEnqueueDictWithoutNames(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       with self.assertRaisesRegexp(ValueError, "must have names"):
         q.enqueue({"a": 12.0})
@@ -151,7 +151,7 @@
         q.enqueue_many({"a": [12.0, 13.0]})
 
   def testParallelEnqueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -177,7 +177,7 @@
       self.assertItemsEqual(elems, results)
 
   def testParallelDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -201,7 +201,7 @@
       self.assertItemsEqual(elems, results)
 
   def testDequeue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -215,7 +215,7 @@
         self.assertEqual([elems[i]], vals)
 
   def testDequeueHalf(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -229,7 +229,7 @@
         self.assertEqual([elems[i]], vals)
 
   def testEnqueueAndBlockingDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -259,7 +259,7 @@
         self.assertEqual([elem], result)
 
   def testMultiEnqueueAndDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
       elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
       enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -275,12 +275,12 @@
         self.assertEqual([y], y_val)
 
   def testQueueSizeEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       self.assertEqual([0], q.size().eval())
 
   def testQueueSizeAfterEnqueueAndDequeue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue()
@@ -293,7 +293,7 @@
       self.assertEqual(0, size.eval())
 
   def testEnqueueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -306,7 +306,7 @@
         self.assertEqual([elems[i % 4]], vals)
 
   def testEmptyEnqueueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       empty_t = constant_op.constant(
           [], dtype=dtypes_lib.float32, shape=[0, 2, 3])
@@ -318,7 +318,7 @@
       self.assertEqual([0], size_t.eval())
 
   def testEmptyDequeueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue_many(0)
@@ -328,7 +328,7 @@
       self.assertEqual([], dequeued_t.eval().tolist())
 
   def testEmptyDequeueUpTo(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue_up_to(0)
@@ -338,14 +338,14 @@
       self.assertEqual([], dequeued_t.eval().tolist())
 
   def testEmptyDequeueManyWithNoShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       # Expect the operation to fail due to the shape not being constrained.
       with self.assertRaisesOpError("specified shapes"):
         q.dequeue_many(0).eval()
 
   def testMultiEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.float32, dtypes_lib.int32))
       float_elems = [10.0, 20.0, 30.0, 40.0]
       int_elems = [[1, 2], [3, 4], [5, 6], [7, 8]]
@@ -361,7 +361,7 @@
         self.assertAllEqual(int_elems[i % 4], int_val)
 
   def testDequeueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@
       self.assertAllEqual(elems[4:8], dequeued_t.eval())
 
   def testDequeueUpToNoBlocking(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -385,7 +385,7 @@
       self.assertAllEqual(elems[4:8], dequeued_t.eval())
 
   def testMultiDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
       float_elems = [
@@ -416,7 +416,7 @@
       self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
 
   def testMultiDequeueUpToNoBlocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
       float_elems = [
@@ -440,7 +440,7 @@
       self.assertAllEqual(int_elems[4:8], int_val)
 
   def testHighDimension(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, (4, 4, 4, 4))
       elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
       enqueue_op = q.enqueue_many((elems,))
@@ -494,7 +494,7 @@
                       array_ops.placeholder(dtypes_lib.int32)))
 
   def testEnqueueWrongShapeAtRuntime(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
           (2, 2), (3, 3)))
       elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
@@ -506,7 +506,7 @@
                  feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
 
   def testEnqueueDequeueManyWrongShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
           (2, 2), (3, 3)))
       elems_ok = np.array([1] * 8).reshape((2, 2, 2)).astype(np.int32)
@@ -521,7 +521,7 @@
         dequeued_t.eval()
 
   def testParallelEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
       elems = [10.0 * x for x in range(100)]
       enqueue_op = q.enqueue_many((elems,))
@@ -540,7 +540,7 @@
       self.assertItemsEqual(dequeued_t.eval(), elems * 10)
 
   def testParallelDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
       elems = [10.0 * x for x in range(1000)]
       enqueue_op = q.enqueue_many((elems,))
@@ -562,7 +562,7 @@
       self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelDequeueUpTo(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
       elems = [10.0 * x for x in range(1000)]
       enqueue_op = q.enqueue_many((elems,))
@@ -586,7 +586,7 @@
       self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelEnqueueAndDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(50, dtypes_lib.float32, shapes=())
       initial_elements = [10.0] * 49
       q.enqueue_many((initial_elements,)).run()
@@ -619,7 +619,7 @@
         self.assertTrue(elem in (10.0, 20.0))
 
   def testMixtureOfEnqueueAndEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
       enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
       enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -655,7 +655,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testMixtureOfDequeueAndDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
       enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
       dequeued_t = q.dequeue()
@@ -689,7 +689,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -716,7 +716,7 @@
       self.assertAllEqual(elems, dequeued_elems)
 
   def testBlockingDequeueUpTo(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -743,7 +743,7 @@
       self.assertAllEqual(elems, dequeued_elems)
 
   def testDequeueManyWithTensorParameter(self):
-    with self.test_session():
+    with self.cached_session():
       # Define a first queue that contains integer counts.
       dequeue_counts = [random.randint(1, 10) for _ in range(100)]
       count_q = data_flow_ops.FIFOQueue(100, dtypes_lib.int32, ())
@@ -768,7 +768,7 @@
       self.assertEqual(elems, dequeued_elems)
 
   def testDequeueFromClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -786,7 +786,7 @@
         dequeued_t.eval()
 
   def testBlockingDequeueFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -812,7 +812,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       close_op = q.close()
       dequeued_t = q.dequeue()
@@ -832,7 +832,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueManyFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -857,7 +857,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueManyButNotAllFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -882,7 +882,7 @@
       dequeue_thread.join()
 
   def testDequeueUpToFromClosedQueueReturnsRemainder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -904,7 +904,7 @@
       dequeue_thread.join()
 
   def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32, ())
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -941,7 +941,7 @@
       close_thread.join()
 
   def testClosedBlockingDequeueManyRestoresPartialBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, (dtypes_lib.float32, dtypes_lib.float32), (
           (), ()))
       elems_a = [1.0, 2.0, 3.0]
@@ -974,7 +974,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueManyFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       close_op = q.close()
       dequeued_t = q.dequeue_many(4)
@@ -994,7 +994,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueUpToFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
       close_op = q.close()
       dequeued_t = q.dequeue_up_to(4)
@@ -1014,7 +1014,7 @@
       dequeue_thread.join()
 
   def testEnqueueToClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       close_op = q.close()
@@ -1027,7 +1027,7 @@
         enqueue_op.run()
 
   def testEnqueueManyToClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1041,7 +1041,7 @@
         enqueue_op.run()
 
   def testBlockingEnqueueToFullQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1064,7 +1064,7 @@
       thread.join()
 
   def testBlockingEnqueueManyToFullQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1091,7 +1091,7 @@
       thread.join()
 
   def testBlockingEnqueueBeforeClose(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1128,7 +1128,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingEnqueueManyBeforeClose(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
       elems = [10.0, 20.0, 30.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1161,7 +1161,7 @@
         self.assertEqual(elem, dequeued_t.eval())
 
   def testDoesNotLoseValue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.FIFOQueue(1, dtypes_lib.float32)
       enqueue_op = q.enqueue((10.0,))
       size_t = q.size()
@@ -1171,7 +1171,7 @@
         self.assertEqual(size_t.eval(), [1])
 
   def testSharedQueueSameSession(self):
-    with self.test_session():
+    with self.cached_session():
       q1 = data_flow_ops.FIFOQueue(
           1, dtypes_lib.float32, shared_name="shared_queue")
       q1.enqueue((10.0,)).run()
@@ -1201,7 +1201,7 @@
       self.assertEqual(q2_size_t.eval(), [0])
 
   def testIncompatibleSharedQueueErrors(self):
-    with self.test_session():
+    with self.cached_session():
       q_a_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_a")
       q_a_2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32, shared_name="q_a")
       q_a_1.queue_ref.op.run()
@@ -1244,7 +1244,7 @@
         q_f_2.queue_ref.op.run()
 
   def testSelectQueue(self):
-    with self.test_session():
+    with self.cached_session():
       num_queues = 10
       qlist = list()
       for _ in xrange(num_queues):
@@ -1257,7 +1257,7 @@
         self.assertEqual(q.dequeue().eval(), 10.0)
 
   def testSelectQueueOutOfRange(self):
-    with self.test_session():
+    with self.cached_session():
       q1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       q2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32)
       enq_q = data_flow_ops.FIFOQueue.from_list(3, [q1, q2])
@@ -1281,7 +1281,7 @@
       sess.run(enqueue_many_op)
 
   def testResetOfBlockingOperation(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q_empty = data_flow_ops.FIFOQueue(5, dtypes_lib.float32, ())
       dequeue_op = q_empty.dequeue()
       dequeue_many_op = q_empty.dequeue_many(1)
@@ -1309,7 +1309,7 @@
         t.join()
 
   def testBigEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(5, dtypes_lib.int32, ((),))
       elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
       enq = q.enqueue_many((elem,))
@@ -1354,7 +1354,7 @@
       self.assertAllEqual(elem, results)
 
   def testBigDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(2, dtypes_lib.int32, ((),))
       elem = np.arange(4, dtype=np.int32)
       enq_list = [q.enqueue((e,)) for e in elem]
@@ -1380,7 +1380,7 @@
       self.assertAllEqual(elem, results)
 
   def testDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dtypes = [
           dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
           dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
@@ -1411,7 +1411,7 @@
         self.assertAllEqual(input_elem, output_elem)
 
   def testDequeueEnqueueFail(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
       a = q.dequeue()
       b = control_flow_ops.Assert(False, ["Before enqueue"])
@@ -1474,7 +1474,7 @@
     self.assertEqual(["i", "f"], q.names)
 
   def testEnqueueDequeueOneComponent(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(
           10, dtypes_lib.float32, shapes=((),), names="f")
       # Verify that enqueue() checks that when using names we must enqueue a
@@ -1519,7 +1519,7 @@
       self.assertEqual([40.0, 50.0], list(f))
 
   def testEnqueueDequeueMultipleComponent(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32, dtypes_lib.string),
           shapes=((), (), ()),
@@ -1600,7 +1600,7 @@
         sess.run(dequeued_t)
 
   def testReusableAfterTimeout(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
       dequeued_t = q.dequeue()
       enqueue_op = q.enqueue(37)
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
index faac7d8..f89d206 100644
--- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -127,7 +127,7 @@
     Returns:
       None
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p, r, c = nn_ops.fractional_avg_pool(
           input_tensor,
           pooling_ratio,
@@ -160,7 +160,7 @@
           overlapping))
       rand_mat = self._PRNG.randint(10, size=tensor_shape)
       pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         p, r, c = nn_ops.fractional_avg_pool(
             rand_mat.astype(np.float32),
             pooling_ratio,
@@ -234,7 +234,7 @@
         [4, 4, 5, 9, 7, 2]
     ])
     # pyformat: enable
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Since deterministic = True, seed and seed2 are fixed. Therefore r, and c
       # are the same each time. We can have an expected result precomputed.
       # r = [0, 2, 4, 6]
@@ -314,7 +314,7 @@
 
   def testDifferentInputTensorShape(self):
     """Runs the operation in one session with different input tensor shapes."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_holder = array_ops.placeholder(dtypes.float32,
                                            [None, None, None, 3])
       pooling_ratio = [1, 1.5, 1.5, 1]
@@ -389,7 +389,7 @@
           num_cols = col_window_size * 7
           for num_channels in [1, 2]:
             input_shape = (num_batches, num_rows, num_cols, num_channels)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(
                   self._GenerateRandomInputTensor(input_shape).astype(
                       np.float32))
@@ -428,7 +428,7 @@
           num_cols = (col_window_size - 1) * 7 + 1
           for num_channels in [1, 2]:
             input_shape = (num_batches, num_rows, num_cols, num_channels)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(
                   self._GenerateRandomInputTensor(input_shape).astype(
                       np.float32))
@@ -468,7 +468,7 @@
 
     for pseudo_random in True, False:
       for overlapping in True, False:
-        with self.test_session() as _:
+        with self.cached_session() as _:
           input_tensor = constant_op.constant(input_data, shape=input_shape)
           output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
               input_tensor,
@@ -501,7 +501,7 @@
           for num_channels in [1, 3]:
             input_shape = (num_batches, num_rows, num_cols, num_channels)
             input_data = self._GenerateRandomInputTensor(input_shape)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(input_data, shape=input_shape)
               output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
                   input_tensor,
@@ -532,7 +532,7 @@
     overlapping = True
     pseudo_random = False
 
-    with self.test_session() as _:
+    with self.cached_session() as _:
       input_tensor = constant_op.constant(input_data, shape=input_shape)
       output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
           input_tensor,
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
index 6477c9e..9b94ca8 100644
--- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -127,7 +127,7 @@
     Returns:
       None
     """
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p, r, c = nn_ops.fractional_max_pool(
           input_tensor,
           pooling_ratio,
@@ -160,7 +160,7 @@
           overlapping))
       rand_mat = self._PRNG.randint(10, size=tensor_shape)
       pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         p, r, c = nn_ops.fractional_max_pool(
             rand_mat,
             pooling_ratio,
@@ -285,7 +285,7 @@
 
   def testDifferentInputTensorShape(self):
     """Runs the operation in one session with different input tensor shapes."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_holder = array_ops.placeholder(dtypes.float32,
                                            [None, None, None, 3])
       pooling_ratio = [1, 1.5, 1.5, 1]
@@ -374,7 +374,7 @@
           num_cols = col_window_size * 7
           for num_channels in [1, 2]:
             input_shape = (num_batches, num_rows, num_cols, num_channels)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(
                   self._GenerateUniqueRandomInputTensor(input_shape))
               window_size = [1, row_window_size, col_window_size, 1]
@@ -409,7 +409,7 @@
           num_cols = (col_window_size - 1) * 7 + 1
           for num_channels in [1, 2]:
             input_shape = (num_batches, num_rows, num_cols, num_channels)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(
                   self._GenerateUniqueRandomInputTensor(input_shape))
               window_size = [1, row_window_size, col_window_size, 1]
@@ -447,7 +447,7 @@
 
     for pseudo_random in True, False:
       for overlapping in True, False:
-        with self.test_session() as _:
+        with self.cached_session() as _:
           input_tensor = constant_op.constant(input_data, shape=input_shape)
           output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
               input_tensor,
@@ -482,7 +482,7 @@
             input_data = self._GenerateUniqueRandomInputTensor(input_shape)
             # Add some randomness to make input_data not so 'integer'
             input_data += self._PRNG.random_sample(input_shape)
-            with self.test_session() as _:
+            with self.cached_session() as _:
               input_tensor = constant_op.constant(input_data, shape=input_shape)
               output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
                   input_tensor,
@@ -515,7 +515,7 @@
     overlapping = True
     pseudo_random = False
 
-    with self.test_session() as _:
+    with self.cached_session() as _:
       input_tensor = constant_op.constant(input_data, shape=input_shape)
       output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
           input_tensor,
@@ -579,7 +579,7 @@
          0.0, 0.0, 0.0, 0.0,
          6.0, 0.0, 21.0, 0.0],
         input_size)  # pyformat: disable
-    with self.test_session() as _:
+    with self.cached_session() as _:
       # Test when overlapping is False
       input_tensor = constant_op.constant(input_data, shape=input_size)
       output_tensor = constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 1e76ad7..e39daf1 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -20,6 +20,7 @@
 
 import numpy as np
 
+from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import iterator_ops
@@ -59,42 +60,48 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldl_Simple(self):
-    with self.test_session():
-      elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
 
-      r = functional_ops.foldl(
-          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
-          elems)
-      self.assertAllEqual(208, self.evaluate(r))
+    r = functional_ops.foldl(
+        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+        elems)
+    self.assertAllEqual(208, self.evaluate(r))
 
-      r = functional_ops.foldl(
-          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
-          elems,
-          initializer=10)
-      self.assertAllEqual(880, self.evaluate(r))
+    r = functional_ops.foldl(
+        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+        elems,
+        initializer=10)
+    self.assertAllEqual(880, self.evaluate(r))
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldl_SingleInputMultiOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array([1, -1.0])
-      r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
-      r_value = self.evaluate(r)
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array([1, -1.0])
+    r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
+    r_value = self.evaluate(r)
 
-      self.assertAllEqual(22, r_value[0])
-      self.assertAllEqual(20, r_value[1])
+    self.assertAllEqual(22, r_value[0])
+    self.assertAllEqual(20, r_value[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldl_MultiInputSingleOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array(1.0)
-      r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
-                               initializer)
-      self.assertAllEqual(1, self.evaluate(r))
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array(1.0)
+    r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
+                             initializer)
+    self.assertAllEqual(1, self.evaluate(r))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testFoldl_MultiInputDifferentDimsSingleOutput(self):
+    elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
+    other_elems = np.array([-1.0, 1.0])
+    initializer = np.array([0.0, 0.0, 0.0])
+    r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
+                             (elems, other_elems), initializer)
+    self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))
 
   def testFoldl_Scoped(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope("root") as varscope:
         elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
 
@@ -114,42 +121,39 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldr_Simple(self):
-    with self.test_session():
-      elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
+    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
 
-      r = functional_ops.foldr(
-          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
-          elems)
-      self.assertAllEqual(450, self.evaluate(r))
+    r = functional_ops.foldr(
+        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+        elems)
+    self.assertAllEqual(450, self.evaluate(r))
 
-      r = functional_ops.foldr(
-          lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
-          elems,
-          initializer=10)
-      self.assertAllEqual(1282, self.evaluate(r))
+    r = functional_ops.foldr(
+        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
+        elems,
+        initializer=10)
+    self.assertAllEqual(1282, self.evaluate(r))
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldr_SingleInputMultiOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array([1, -1.0])
-      r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
-      r_value = self.evaluate(r)
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array([1, -1.0])
+    r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
+    r_value = self.evaluate(r)
 
-      self.assertAllEqual(22, r_value[0])
-      self.assertAllEqual(20, r_value[1])
+    self.assertAllEqual(22, r_value[0])
+    self.assertAllEqual(20, r_value[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldr_MultiInputSingleOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array(1.0)
-      r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
-                               initializer)
-      self.assertAllEqual(1, self.evaluate(r))
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array(1.0)
+    r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
+                             initializer)
+    self.assertAllEqual(1, self.evaluate(r))
 
   def testFoldr_Scoped(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope("root") as varscope:
         elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
 
@@ -169,7 +173,7 @@
 
   # pylint: disable=unnecessary-lambda
   def testFold_Grad(self):
-    with self.test_session():
+    with self.cached_session():
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
       v = constant_op.constant(2.0, name="v")
       r = functional_ops.foldl(
@@ -185,16 +189,15 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_Simple(self):
-    with self.test_session():
-      nums = [1, 2, 3, 4, 5, 6]
-      elems = constant_op.constant(nums, name="data")
-      r = functional_ops.map_fn(
-          lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
-      self.assertAllEqual(
-          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+    nums = [1, 2, 3, 4, 5, 6]
+    elems = constant_op.constant(nums, name="data")
+    r = functional_ops.map_fn(
+        lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems)
+    self.assertAllEqual(
+        np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
 
   def testMapSparseTensor(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         functional_ops.map_fn(
             lambda x: x,
@@ -211,7 +214,7 @@
       functional_ops.map_fn(lambda x: x, 1)
 
   def testMap_Scoped(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       def double_scoped(x):
         """2x with a dummy 2 that is scoped."""
@@ -242,7 +245,7 @@
         self.assertAllEqual(doubles, self.evaluate(r))
 
   def testMap_Grad(self):
-    with self.test_session():
+    with self.cached_session():
       param = constant_op.constant(2.0)
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
       y = functional_ops.map_fn(
@@ -254,142 +257,131 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_SimpleNotTensor(self):
-    with self.test_session():
-      nums = np.array([1, 2, 3, 4, 5, 6])
-      r = functional_ops.map_fn(
-          lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
-      self.assertAllEqual(
-          np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
+    nums = np.array([1, 2, 3, 4, 5, 6])
+    r = functional_ops.map_fn(
+        lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums)
+    self.assertAllEqual(
+        np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_SingleInputMultiOutput(self):
-    with self.test_session():
-      nums = np.array([1, 2, 3, 4, 5, 6])
-      r = functional_ops.map_fn(
-          lambda x: ((x + 3) * 2, -(x + 3) * 2),
-          nums,
-          dtype=(dtypes.int64, dtypes.int64))
-      self.assertEqual(2, len(r))
-      self.assertEqual((6,), r[0].get_shape())
-      self.assertEqual((6,), r[1].get_shape())
-      received = self.evaluate(r)
-      self.assertAllEqual((nums + 3) * 2, received[0])
-      self.assertAllEqual(-(nums + 3) * 2, received[1])
+    nums = np.array([1, 2, 3, 4, 5, 6])
+    r = functional_ops.map_fn(
+        lambda x: ((x + 3) * 2, -(x + 3) * 2),
+        nums,
+        dtype=(dtypes.int64, dtypes.int64))
+    self.assertEqual(2, len(r))
+    self.assertEqual((6,), r[0].get_shape())
+    self.assertEqual((6,), r[1].get_shape())
+    received = self.evaluate(r)
+    self.assertAllEqual((nums + 3) * 2, received[0])
+    self.assertAllEqual(-(nums + 3) * 2, received[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_MultiOutputMismatchedDtype(self):
-    with self.test_session():
-      nums = np.array([1, 2, 3, 4, 5, 6])
-      with self.assertRaisesRegexp(
-          TypeError, r"two structures don't have the same nested structure"):
-        # lambda emits tuple, but dtype is a list
-        functional_ops.map_fn(
-            lambda x: ((x + 3) * 2, -(x + 3) * 2),
-            nums,
-            dtype=[dtypes.int64, dtypes.int64])
+    nums = np.array([1, 2, 3, 4, 5, 6])
+    with self.assertRaisesRegexp(
+        TypeError, r"two structures don't have the same nested structure"):
+      # lambda emits tuple, but dtype is a list
+      functional_ops.map_fn(
+          lambda x: ((x + 3) * 2, -(x + 3) * 2),
+          nums,
+          dtype=[dtypes.int64, dtypes.int64])
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_MultiInputSingleOutput(self):
-    with self.test_session():
-      nums = np.array([1, 2, 3, 4, 5, 6])
-      r = functional_ops.map_fn(
-          lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
-          dtype=dtypes.int64)
-      self.assertEqual((6,), r.get_shape())
-      received = self.evaluate(r)
-      self.assertAllEqual(nums * nums + (-nums), received)
+    nums = np.array([1, 2, 3, 4, 5, 6])
+    r = functional_ops.map_fn(
+        lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)),
+        dtype=dtypes.int64)
+    self.assertEqual((6,), r.get_shape())
+    received = self.evaluate(r)
+    self.assertAllEqual(nums * nums + (-nums), received)
 
   @test_util.run_in_graph_and_eager_modes
   def testMap_MultiInputSameStructureOutput(self):
-    with self.test_session():
-      nums = np.array([1, 2, 3, 4, 5, 6])
-      r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
-                                (nums, (2 * nums, -nums)))
-      r = [r[0], r[1][0], r[1][1]]
-      self.assertEqual((6,), r[0].get_shape())
-      self.assertEqual((6,), r[1].get_shape())
-      self.assertEqual((6,), r[2].get_shape())
-      received = self.evaluate(r)
-      self.assertAllEqual(2 * nums, received[0])
-      self.assertAllEqual(-nums, received[1])
-      self.assertAllEqual(nums, received[2])
+    nums = np.array([1, 2, 3, 4, 5, 6])
+    r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])),
+                              (nums, (2 * nums, -nums)))
+    r = [r[0], r[1][0], r[1][1]]
+    self.assertEqual((6,), r[0].get_shape())
+    self.assertEqual((6,), r[1].get_shape())
+    self.assertEqual((6,), r[2].get_shape())
+    received = self.evaluate(r)
+    self.assertAllEqual(2 * nums, received[0])
+    self.assertAllEqual(-nums, received[1])
+    self.assertAllEqual(nums, received[2])
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_Simple(self):
-    with self.test_session():
-      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
-      v = constant_op.constant(2.0, name="v")
+    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+    v = constant_op.constant(2.0, name="v")
 
-      # pylint: disable=unnecessary-lambda
-      r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
-      self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
+    # pylint: disable=unnecessary-lambda
+    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
+    self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))
 
-      r = functional_ops.scan(
-          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
-      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
-      # pylint: enable=unnecessary-lambda
+    r = functional_ops.scan(
+        lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
+    self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
+    # pylint: enable=unnecessary-lambda
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_Reverse(self):
-    with self.test_session():
-      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
-      v = constant_op.constant(2.0, name="v")
+    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
+    v = constant_op.constant(2.0, name="v")
 
-      # pylint: disable=unnecessary-lambda
-      r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
-                              reverse=True)
-      self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
-      r = functional_ops.scan(
-          lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
-          reverse=True)
-      self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
-                          self.evaluate(r))
-      # pylint: enable=unnecessary-lambda
+    # pylint: disable=unnecessary-lambda
+    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
+                            reverse=True)
+    self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
+    r = functional_ops.scan(
+        lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
+        reverse=True)
+    self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
+                        self.evaluate(r))
+    # pylint: enable=unnecessary-lambda
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_SingleInputMultiOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = (np.array(1.0), np.array(-1.0))
-      r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
-                              initializer)
-      r_value = self.evaluate(r)
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = (np.array(1.0), np.array(-1.0))
+    r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
+                            initializer)
+    r_value = self.evaluate(r)
 
-      self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
-      self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
+    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
+    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_MultiInputSingleOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array(1.0)
-      # Multiply a * 1 each time
-      r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
-                              (elems + 1, -elems), initializer)
-      self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array(1.0)
+    # Multiply a * 1 each time
+    r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
+                            (elems + 1, -elems), initializer)
+    self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_MultiInputSameTypeOutput(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
-                              (elems, -elems))
-      r_value = self.evaluate(r)
-      self.assertAllEqual(np.cumsum(elems), r_value[0])
-      self.assertAllEqual(np.cumsum(-elems), r_value[1])
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
+                            (elems, -elems))
+    r_value = self.evaluate(r)
+    self.assertAllEqual(np.cumsum(elems), r_value[0])
+    self.assertAllEqual(np.cumsum(-elems), r_value[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testScan_MultiOutputMismatchedInitializer(self):
-    with self.test_session():
-      elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-      initializer = np.array(1.0)
-      # Multiply a * 1 each time
-      with self.assertRaisesRegexp(
-          ValueError, "two structures don't have the same nested structure"):
-        functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
+    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+    initializer = np.array(1.0)
+    # Multiply a * 1 each time
+    with self.assertRaisesRegexp(
+        ValueError, "two structures don't have the same nested structure"):
+      functional_ops.scan(lambda a, x: (a, -a), elems, initializer)
 
   def testScan_Scoped(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope("root") as varscope:
         elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
 
@@ -411,30 +403,29 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testScanFoldl_Nested(self):
-    with self.test_session():
-      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
-      inner_elems = constant_op.constant([0.5, 0.5], name="data")
+    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
+    inner_elems = constant_op.constant([0.5, 0.5], name="data")
 
-      def r_inner(a, x):
-        return functional_ops.foldl(
-            lambda b, y: b * y * x, inner_elems, initializer=a)
+    def r_inner(a, x):
+      return functional_ops.foldl(
+          lambda b, y: b * y * x, inner_elems, initializer=a)
 
-      r = functional_ops.scan(r_inner, elems)
+    r = functional_ops.scan(r_inner, elems)
 
-      # t == 0 (returns 1)
-      # t == 1, a == 1, x == 2 (returns 1)
-      #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
-      #   t_1 == 1, b == 1,      y == 0.5, returns b * y * x = 1
-      # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
-      #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
-      #   t_1 == 1, b == 1.5,    y == 0.5, returns b * y * x = 1.5*1.5
-      # t == 3, a == 2.25, x == 4 (returns 9)
-      #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
-      #   t_1 == 1, b == 4.5,       y == 0.5, returns b * y * x = 9
-      self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
+    # t == 0 (returns 1)
+    # t == 1, a == 1, x == 2 (returns 1)
+    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
+    #   t_1 == 1, b == 1,      y == 0.5, returns b * y * x = 1
+    # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
+    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
+    #   t_1 == 1, b == 1.5,    y == 0.5, returns b * y * x = 1.5*1.5
+    # t == 3, a == 2.25, x == 4 (returns 9)
+    #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
+    #   t_1 == 1, b == 4.5,       y == 0.5, returns b * y * x = 9
+    self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))
 
   def testScan_Control(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       s = array_ops.placeholder(dtypes.float32, shape=[None])
       b = array_ops.placeholder(dtypes.bool)
 
@@ -445,7 +436,7 @@
                                                   b: True}))
 
   def testScan_Grad(self):
-    with self.test_session():
+    with self.cached_session():
       elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
       v = constant_op.constant(2.0, name="v")
 
@@ -470,22 +461,20 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testFoldShape(self):
-    with self.test_session():
-      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
 
-      def fn(_, current_input):
-        return current_input
+    def fn(_, current_input):
+      return current_input
 
-      initializer = constant_op.constant([0, 0, 0])
-      y = functional_ops.foldl(fn, x, initializer=initializer)
-      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+    initializer = constant_op.constant([0, 0, 0])
+    y = functional_ops.foldl(fn, x, initializer=initializer)
+    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
   @test_util.run_in_graph_and_eager_modes
   def testMapShape(self):
-    with self.test_session():
-      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
-      y = functional_ops.map_fn(lambda e: e, x)
-      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+    y = functional_ops.map_fn(lambda e: e, x)
+    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
   def testMapUnknownShape(self):
     x = array_ops.placeholder(dtypes.float32)
@@ -494,15 +483,14 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testMapEmptyScalar(self):
-    with self.test_session():
-      map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
-      self.assertAllEqual([0], map_return.get_shape().dims)
-      self.assertAllEqual([0], self.evaluate(map_return).shape)
+    map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
+    self.assertAllEqual([0], map_return.get_shape().dims)
+    self.assertAllEqual([0], self.evaluate(map_return).shape)
 
   # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
   # so the body of the while loop never executes
   def testMapEmptyTensor(self):
-    with self.test_session():
+    with self.cached_session():
       map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]),
                                          constant_op.constant([]))
       self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
@@ -510,20 +498,19 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testScanShape(self):
-    with self.test_session():
-      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
 
-      def fn(_, current_input):
-        return current_input
+    def fn(_, current_input):
+      return current_input
 
-      initializer = constant_op.constant([0, 0, 0])
-      y = functional_ops.scan(fn, x, initializer=initializer)
-      self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
+    initializer = constant_op.constant([0, 0, 0])
+    y = functional_ops.scan(fn, x, initializer=initializer)
+    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
 
   # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
   # so the body of the while loop never executes
   def testScanEmptyTensor(self):
-    with self.test_session():
+    with self.cached_session():
       x = functional_ops.scan(
           lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
       self.assertAllEqual([0, 2, 4], x.get_shape())
@@ -540,7 +527,7 @@
     self.assertIs(None, y.get_shape().dims)
 
   def testScanVaryingShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
       x_t = array_ops.transpose(x)
       # scan over dimension 0 (with shape None)
@@ -619,7 +606,7 @@
       remote_op = functional_ops.remote_call(
           args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       mul = sess.run(remote_op)
       self.assertEqual(mul, [6])
@@ -643,7 +630,7 @@
           f=_remote_fn,
           target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       mul = sess.run(remote_op)
       self.assertEqual(mul, 9.0)
@@ -667,7 +654,7 @@
           f=_remote_fn,
           target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.global_variables_initializer())
       mul = sess.run(remote_op)
       self.assertEqual(mul, 9.0)
@@ -686,7 +673,7 @@
       remote_op = functional_ops.remote_call(
           args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ret = sess.run(remote_op)
       self.assertAllEqual(ret, [b"a"])
 
@@ -752,6 +739,40 @@
           self.assertAllEqual(Run(sess, 20.), 210.)
           self.assertAllEqual(Run(sess, 100.), 5050.)
 
+  def testWhileLowering(self):
+
+    def Run(n, fetch_by_name):
+      for use_gpu in (True, False):
+        with ops.Graph().as_default() as g:
+
+          @function.Defun(*[dtypes.float32] * 2)
+          def Cond(n, unused_x):
+            return n > 0
+
+          @function.Defun(*[dtypes.float32] * 2)
+          def Body(n, x):
+            return n - 1, x + n
+
+          # outputs: [0, n*(n+1)/2]
+          outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
+
+          # `outputs` is the list of output tensors of the While op. We
+          # arbitrarily choose the 0th tensor to get the While op and set the
+          # lowering attribute on it.
+          outputs[0].op._set_attr("_lower_using_switch_merge",
+                                  attr_value_pb2.AttrValue(b=True))
+          if not fetch_by_name:
+            fetch = outputs[1]
+          else:
+            fetch = "my_while:1"
+        with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+          return sess.run(fetch)
+
+    self.assertAllEqual(Run(20., False), 210.)
+    self.assertAllEqual(Run(20., True), 210.)
+    self.assertAllEqual(Run(100., False), 5050.)
+    self.assertAllEqual(Run(100., True), 5050.)
+
   def testWhileError(self):
     for use_gpu in (True, False):
       with ops.Graph().as_default() as g:
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 033fa95..85bf969 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -147,7 +147,7 @@
 
   def testString(self):
     params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual([b"qwer", b"uiop"],
                           array_ops.gather(params, 1, axis=0).eval())
       self.assertAllEqual([b"asdf", b"qwer"],
@@ -157,7 +157,7 @@
     for unsigned_type in (dtypes.uint32, dtypes.uint64):
       params = self._buildParams(
           np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
-      with self.test_session():
+      with self.cached_session():
         self.assertAllEqual([7, 8, 9],
                             array_ops.gather(params, 1, axis=0).eval())
         self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
diff --git a/tensorflow/python/kernel_tests/gradient_correctness_test.py b/tensorflow/python/kernel_tests/gradient_correctness_test.py
index e93c623..291a69e 100644
--- a/tensorflow/python/kernel_tests/gradient_correctness_test.py
+++ b/tensorflow/python/kernel_tests/gradient_correctness_test.py
@@ -30,7 +30,7 @@
 class GradientCorrectnessTest(test.TestCase):
 
   def testMultipleOutputChainedGradients(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = constant_op.constant(1.0, dtype=dtypes.float32)
       yexp = math_ops.exp(x)
       yexplog = math_ops.log(yexp)
@@ -43,13 +43,13 @@
   def testIdentityGradient(self):
     x = constant_op.constant(3.)
     dx_dx, = gradients_impl.gradients(x, x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllClose(1., sess.run(dx_dx))
 
   def testIntegerIdentityGradient(self):
     x = constant_op.constant(3)
     dx_dx, = gradients_impl.gradients(x, x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllClose(1, sess.run(dx_dx))
 
   def testGradientWithIntegerPath(self):
@@ -57,7 +57,7 @@
     k = math_ops.to_float(math_ops.to_int32(x))
     y = x * k
     dy_dx, = gradients_impl.gradients(y, x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllClose([3., 4.], sess.run(dy_dx))
 
   def testNoIntegerGradient1(self):
diff --git a/tensorflow/python/kernel_tests/identity_n_op_py_test.py b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
index 408b173..518733c 100644
--- a/tensorflow/python/kernel_tests/identity_n_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
@@ -28,7 +28,7 @@
 class IdentityNOpTest(test.TestCase):
 
   def testInt32String_6(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       [value0, value1] = sess.run(
           array_ops.identity_n([[1, 2, 3, 4, 5, 6],
                                 [b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))
@@ -37,7 +37,7 @@
         np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
 
   def testInt32_shapes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       inp0 = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
       inp1 = constant_op.constant([11, 21, 31, 41, 51, 61], shape=[3, 2])
       inp2 = constant_op.constant(
@@ -52,12 +52,12 @@
 
   def testString(self):
     source = [b"A", b"b", b"C", b"d", b"E", b"f"]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       [value] = sess.run(array_ops.identity_n([source]))
     self.assertAllEqual(source, value)
 
   def testIdentityShape(self):
-    with self.test_session():
+    with self.cached_session():
       shape = [2, 3]
       array_2x3 = [[1, 2, 3], [6, 5, 4]]
       tensor = constant_op.constant(array_2x3)
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 49fb76d..37f9f71 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -31,24 +31,24 @@
 class IdentityOpTest(test.TestCase):
 
   def testInt32_6(self):
-    with self.test_session():
+    with self.cached_session():
       value = array_ops.identity([1, 2, 3, 4, 5, 6]).eval()
     self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value)
 
   def testInt32_2_3(self):
-    with self.test_session():
+    with self.cached_session():
       inp = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
       value = array_ops.identity(inp).eval()
     self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value)
 
   def testString(self):
     source = [b"A", b"b", b"C", b"d", b"E", b"f"]
-    with self.test_session():
+    with self.cached_session():
       value = array_ops.identity(source).eval()
     self.assertAllEqual(source, value)
 
   def testIdentityShape(self):
-    with self.test_session():
+    with self.cached_session():
       shape = [2, 3]
       array_2x3 = [[1, 2, 3], [6, 5, 4]]
       tensor = constant_op.constant(array_2x3)
@@ -59,7 +59,7 @@
                         array_ops.identity(np.array(array_2x3)).get_shape())
 
   def testRefIdentityShape(self):
-    with self.test_session():
+    with self.cached_session():
       shape = [2, 3]
       tensor = variables.Variable(
           constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index fafeea8..6fdb497 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -30,7 +30,7 @@
 
   def _validateInTopK(self, predictions, target, k, expected):
     np_ans = np.array(expected)
-    with self.test_session():
+    with self.cached_session():
       precision = nn_ops.in_top_k(predictions, target, k)
       out = precision.eval()
       self.assertAllClose(np_ans, out)
@@ -65,7 +65,7 @@
   def testBadTarget(self):
     predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
     target = [0, 80000]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "target.*out of range"):
         nn_ops.in_top_k(predictions, target, 2).eval()
@@ -75,7 +75,7 @@
     target = [0, 2]
     k = constant_op.constant(3)
     np_ans = np.array([False, True])
-    with self.test_session():
+    with self.cached_session():
       precision = nn_ops.in_top_k(predictions, target, k)
       out = precision.eval()
       self.assertAllClose(np_ans, out)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index f6097ad..79ce965 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -343,7 +343,7 @@
 
   def testZeroSize(self):
     shape = [0, 2]
-    with self.test_session():
+    with self.cached_session():
       x = variable_scope.get_variable(
           "x",
           shape=shape,
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
index 6e89436..90759c2 100644
--- a/tensorflow/python/kernel_tests/inplace_ops_test.py
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -153,7 +153,7 @@
       self.assertAllClose(vy, vz)
 
   def testError(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    "must be a vector"):
         _ = inplace_ops.inplace_update([[1.]], [[0]], [[10]]).eval()
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index 61944f7..afa2419 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -37,7 +37,7 @@
       with tempfile.NamedTemporaryFile(
           prefix='ReadFileTest', dir=self.get_temp_dir(), delete=False) as temp:
         temp.write(contents)
-      with self.test_session():
+      with self.cached_session():
         read = io_ops.read_file(temp.name)
         self.assertEqual([], read.get_shape())
         self.assertEqual(read.eval(), contents)
@@ -51,7 +51,7 @@
           prefix='WriteFileTest', dir=self.get_temp_dir(),
           delete=False) as temp:
         pass
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         w = io_ops.write_file(temp.name, contents)
         sess.run(w)
         with open(temp.name, 'rb') as f:
@@ -65,7 +65,7 @@
       contents = compat.as_bytes(contents)
       subdir = os.path.join(self.get_temp_dir(), 'subdir1')
       filepath = os.path.join(subdir, 'subdir2', 'filename')
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         w = io_ops.write_file(filepath, contents)
         sess.run(w)
         with open(filepath, 'rb') as f:
@@ -88,7 +88,7 @@
             prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases
     ]
 
-    with self.test_session():
+    with self.cached_session():
       # Test exact match without wildcards.
       for f in files:
         self.assertEqual(
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index f4ec3e3..be2e31c 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -25,6 +25,22 @@
 )
 
 cuda_py_test(
+    name = "linear_operator_addition_test",
+    size = "small",
+    srcs = ["linear_operator_addition_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg",
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+cuda_py_test(
     name = "linear_operator_block_diag_test",
     size = "medium",
     srcs = ["linear_operator_block_diag_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
new file mode 100644
index 0000000..7c79fed
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -0,0 +1,412 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_addition
+from tensorflow.python.platform import test
+
+linalg = linalg_lib
+random_seed.set_random_seed(23)
+rng = np.random.RandomState(0)
+
+add_operators = linear_operator_addition.add_operators
+
+
+# pylint: disable=unused-argument
+class _BadAdder(linear_operator_addition._Adder):
+  """Adder that will fail if used."""
+
+  def can_add(self, op1, op2):
+    raise AssertionError("BadAdder.can_add called!")
+
+  def _add(self, op1, op2, operator_name, hints):
+    raise AssertionError("This line should not be reached")
+
+
+# pylint: enable=unused-argument
+
+
+class LinearOperatorAdditionCorrectnessTest(test.TestCase):
+  """Tests correctness of addition with combinations of a few Adders.
+
+  Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
+  add_operators should reduce all operators resulting in one single operator.
+
+  This shows that we are able to correctly combine adders using the tiered
+  system.  All Adders should be tested separately, and there is no need to test
+  every Adder within this class.
+  """
+
+  def test_one_operator_is_returned_unchanged(self):
+    op_a = linalg.LinearOperatorDiag([1., 1.])
+    op_sum = add_operators([op_a])
+    self.assertEqual(1, len(op_sum))
+    self.assertIs(op_sum[0], op_a)
+
+  def test_at_least_one_operators_required(self):
+    with self.assertRaisesRegexp(ValueError, "must contain at least one"):
+      add_operators([])
+
+  def test_attempting_to_add_numbers_raises(self):
+    with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
+      add_operators([1, 2])
+
+  def test_two_diag_operators(self):
+    op_a = linalg.LinearOperatorDiag(
+        [1., 1.], is_positive_definite=True, name="A")
+    op_b = linalg.LinearOperatorDiag(
+        [2., 2.], is_positive_definite=True, name="B")
+    with self.test_session():
+      op_sum = add_operators([op_a, op_b])
+      self.assertEqual(1, len(op_sum))
+      op = op_sum[0]
+      self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
+      self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
+      # Adding positive definite operators produces positive def.
+      self.assertTrue(op.is_positive_definite)
+      # Real diagonal ==> self-adjoint.
+      self.assertTrue(op.is_self_adjoint)
+      # Positive definite ==> non-singular
+      self.assertTrue(op.is_non_singular)
+      # Enforce particular name for this simple case
+      self.assertEqual("Add/B__A/", op.name)
+
+  def test_three_diag_operators(self):
+    op1 = linalg.LinearOperatorDiag(
+        [1., 1.], is_positive_definite=True, name="op1")
+    op2 = linalg.LinearOperatorDiag(
+        [2., 2.], is_positive_definite=True, name="op2")
+    op3 = linalg.LinearOperatorDiag(
+        [3., 3.], is_positive_definite=True, name="op3")
+    with self.test_session():
+      op_sum = add_operators([op1, op2, op3])
+      self.assertEqual(1, len(op_sum))
+      op = op_sum[0]
+      self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
+      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+      # Adding positive definite operators produces positive def.
+      self.assertTrue(op.is_positive_definite)
+      # Real diagonal ==> self-adjoint.
+      self.assertTrue(op.is_self_adjoint)
+      # Positive definite ==> non-singular
+      self.assertTrue(op.is_non_singular)
+
+  def test_diag_tril_diag(self):
+    op1 = linalg.LinearOperatorDiag(
+        [1., 1.], is_non_singular=True, name="diag_a")
+    op2 = linalg.LinearOperatorLowerTriangular(
+        [[2., 0.], [0., 2.]],
+        is_self_adjoint=True,
+        is_non_singular=True,
+        name="tril")
+    op3 = linalg.LinearOperatorDiag(
+        [3., 3.], is_non_singular=True, name="diag_b")
+    with self.test_session():
+      op_sum = add_operators([op1, op2, op3])
+      self.assertEqual(1, len(op_sum))
+      op = op_sum[0]
+      self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
+      self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+
+      # The diag operators will be self-adjoint (because real and diagonal).
+      # The TriL operator has the self-adjoint hint set.
+      self.assertTrue(op.is_self_adjoint)
+
+      # Even though op1/2/3 are non-singular, this does not imply op is.
+      # Since no custom hint was provided, we default to None (unknown).
+      self.assertEqual(None, op.is_non_singular)
+
+  def test_matrix_diag_tril_diag_uses_custom_name(self):
+    op0 = linalg.LinearOperatorFullMatrix(
+        [[-1., -1.], [-1., -1.]], name="matrix")
+    op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
+    op2 = linalg.LinearOperatorLowerTriangular(
+        [[2., 0.], [1.5, 2.]], name="tril")
+    op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
+    with self.test_session():
+      op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
+      self.assertEqual(1, len(op_sum))
+      op = op_sum[0]
+      self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
+      self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
+      self.assertEqual("my_operator", op.name)
+
+  def test_incompatible_domain_dimensions_raises(self):
+    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+    op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
+    with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
+      add_operators([op1, op2])
+
+  def test_incompatible_range_dimensions_raises(self):
+    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+    op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
+    with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
+      add_operators([op1, op2])
+
+  def test_non_broadcastable_batch_shape_raises(self):
+    op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
+    op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
+    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
+      add_operators([op1, op2])
+
+
+class LinearOperatorOrderOfAdditionTest(test.TestCase):
+  """Test that the order of addition is done as specified by tiers."""
+
+  def test_tier_0_additions_done_in_tier_0(self):
+    diag1 = linalg.LinearOperatorDiag([1.])
+    diag2 = linalg.LinearOperatorDiag([1.])
+    diag3 = linalg.LinearOperatorDiag([1.])
+    addition_tiers = [
+        [linear_operator_addition._AddAndReturnDiag()],
+        [_BadAdder()],
+    ]
+    # Should not raise since all were added in tier 0, and tier 1 (with the
+    # _BadAdder) was never reached.
+    op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
+    self.assertEqual(1, len(op_sum))
+    self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)
+
+  def test_tier_1_additions_done_by_tier_1(self):
+    diag1 = linalg.LinearOperatorDiag([1.])
+    diag2 = linalg.LinearOperatorDiag([1.])
+    tril = linalg.LinearOperatorLowerTriangular([[1.]])
+    addition_tiers = [
+        [linear_operator_addition._AddAndReturnDiag()],
+        [linear_operator_addition._AddAndReturnTriL()],
+        [_BadAdder()],
+    ]
+    # Should not raise since all were added by tier 1, and the
+    # _BadAdder was never reached.
+    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+    self.assertEqual(1, len(op_sum))
+    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+  def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
+    diag1 = linalg.LinearOperatorDiag([1.])
+    diag2 = linalg.LinearOperatorDiag([1.])
+    tril = linalg.LinearOperatorLowerTriangular([[1.]])
+    addition_tiers = [
+        [linear_operator_addition._AddAndReturnTriL()],
+        [linear_operator_addition._AddAndReturnDiag()],
+        [_BadAdder()],
+    ]
+    # Tier 0 could convert to TriL, and this converted everything to TriL,
+    # including the Diags.
+    # Tier 1 was never used.
+    # Tier 2 was never used (therefore, _BadAdder didn't raise).
+    op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+    self.assertEqual(1, len(op_sum))
+    self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+  def test_cannot_add_everything_so_return_more_than_one_operator(self):
+    diag1 = linalg.LinearOperatorDiag([1.])
+    diag2 = linalg.LinearOperatorDiag([2.])
+    tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
+    addition_tiers = [
+        [linear_operator_addition._AddAndReturnDiag()],
+    ]
+    # Tier 0 (the only tier) can only convert to Diag, so it combines the two
+    # diags, but the TriL is unchanged.
+    # Result should contain two operators, one Diag, one TriL.
+    op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
+    self.assertEqual(2, len(op_sum))
+    found_diag = False
+    found_tril = False
+    with self.test_session():
+      for op in op_sum:
+        if isinstance(op, linalg.LinearOperatorDiag):
+          found_diag = True
+          self.assertAllClose([[3.]], op.to_dense().eval())
+        if isinstance(op, linalg.LinearOperatorLowerTriangular):
+          found_tril = True
+          self.assertAllClose([[5.]], op.to_dense().eval())
+      self.assertTrue(found_diag and found_tril)
+
+  def test_intermediate_tier_is_not_skipped(self):
+    diag1 = linalg.LinearOperatorDiag([1.])
+    diag2 = linalg.LinearOperatorDiag([1.])
+    tril = linalg.LinearOperatorLowerTriangular([[1.]])
+    addition_tiers = [
+        [linear_operator_addition._AddAndReturnDiag()],
+        [_BadAdder()],
+        [linear_operator_addition._AddAndReturnTriL()],
+    ]
+    # tril cannot be added in tier 0, and the intermediate tier 1 with the
+    # BadAdder will catch it and raise.
+    with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
+      add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+
+
+class AddAndReturnScaledIdentityTest(test.TestCase):
+
+  def setUp(self):
+    self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
+
+  def test_identity_plus_identity(self):
+    id1 = linalg.LinearOperatorIdentity(num_rows=2)
+    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(id1, id2))
+    operator = self._adder.add(id1, id2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+    with self.test_session():
+      self.assertAllClose(2 *
+                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+                          operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+  def test_identity_plus_scaled_identity(self):
+    id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(id1, id2))
+    operator = self._adder.add(id1, id2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+    with self.test_session():
+      self.assertAllClose(3.2 *
+                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+                          operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+  def test_scaled_identity_plus_scaled_identity(self):
+    id1 = linalg.LinearOperatorScaledIdentity(
+        num_rows=2, multiplier=[2.2, 2.2, 2.2])
+    id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(id1, id2))
+    operator = self._adder.add(id1, id2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+    with self.test_session():
+      self.assertAllClose(1.2 *
+                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+                          operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnDiagTest(test.TestCase):
+
+  def setUp(self):
+    self._adder = linear_operator_addition._AddAndReturnDiag()
+
+  def test_identity_plus_identity_returns_diag(self):
+    id1 = linalg.LinearOperatorIdentity(num_rows=2)
+    id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(id1, id2))
+    operator = self._adder.add(id1, id2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+    with self.test_session():
+      self.assertAllClose(2 *
+                          linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+                          operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+  def test_diag_plus_diag(self):
+    diag1 = rng.rand(2, 3, 4)
+    diag2 = rng.rand(4)
+    op1 = linalg.LinearOperatorDiag(diag1)
+    op2 = linalg.LinearOperatorDiag(diag2)
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(op1, op2))
+    operator = self._adder.add(op1, op2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+    with self.test_session():
+      self.assertAllClose(
+          linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
+          operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnTriLTest(test.TestCase):
+
+  def setUp(self):
+    self._adder = linear_operator_addition._AddAndReturnTriL()
+
+  def test_diag_plus_tril(self):
+    diag = linalg.LinearOperatorDiag([1., 2.])
+    tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=True, is_non_singular=True)
+
+    self.assertTrue(self._adder.can_add(diag, diag))
+    self.assertTrue(self._adder.can_add(diag, tril))
+    operator = self._adder.add(diag, tril, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
+
+    with self.test_session():
+      self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
+    self.assertTrue(operator.is_positive_definite)
+    self.assertTrue(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnMatrixTest(test.TestCase):
+
+  def setUp(self):
+    self._adder = linear_operator_addition._AddAndReturnMatrix()
+
+  def test_diag_plus_diag(self):
+    diag1 = linalg.LinearOperatorDiag([1., 2.])
+    diag2 = linalg.LinearOperatorDiag([-1., 3.])
+    hints = linear_operator_addition._Hints(
+        is_positive_definite=False, is_non_singular=False)
+
+    self.assertTrue(self._adder.can_add(diag1, diag2))
+    operator = self._adder.add(diag1, diag2, "my_operator", hints)
+    self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
+
+    with self.test_session():
+      self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
+    self.assertFalse(operator.is_positive_definite)
+    self.assertFalse(operator.is_non_singular)
+    self.assertEqual("my_operator", operator.name)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 0e4e584..cd6a34d 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -40,7 +40,7 @@
 class ShapeTest(test_lib.TestCase):
 
   def testBatchGradientUnknownSize(self):
-    with self.test_session():
+    with self.cached_session():
       batch_size = constant_op.constant(3)
       matrix_size = constant_op.constant(4)
       batch_identity = array_ops.tile(
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 2f28d37..aa17f72 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -128,7 +128,7 @@
       matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
                                                        6 + 6j]]).astype(dtype)
       expected_transposed = np.conj(matrix_np.T)
-      with self.test_session():
+      with self.cached_session():
         matrix = ops.convert_to_tensor(matrix_np)
         transposed = linalg.adjoint(matrix)
         self.assertEqual((3, 2), transposed.get_shape())
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 9b6aee6..0f56077 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,9 +170,8 @@
             list_ops.tensor_list_pop_back(
                 l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
 
-  @test_util.run_in_graph_and_eager_modes
   def testGraphStack(self):
-    with context.graph_mode(), self.test_session():
+    with self.cached_session():
       tl = list_ops.empty_tensor_list(
           element_shape=constant_op.constant([1], dtype=dtypes.int32),
           element_dtype=dtypes.int32)
@@ -182,9 +181,8 @@
               list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
           [[1]])
 
-  @test_util.run_in_graph_and_eager_modes
   def testGraphStackInLoop(self):
-    with context.graph_mode(), self.test_session():
+    with self.cached_session():
       t1 = list_ops.empty_tensor_list(
           element_shape=constant_op.constant([], dtype=dtypes.int32),
           element_dtype=dtypes.int32)
@@ -200,9 +198,8 @@
       s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
       self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
 
-  @test_util.run_in_graph_and_eager_modes
   def testGraphStackSwitchDtype(self):
-    with context.graph_mode(), self.test_session():
+    with self.cached_session():
       list_ = list_ops.empty_tensor_list(
           element_shape=constant_op.constant([], dtype=dtypes.int32),
           element_dtype=dtypes.int32)
@@ -222,9 +219,8 @@
       np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
       self.assertAllEqual(self.evaluate(s1), np_s1)
 
-  @test_util.run_in_graph_and_eager_modes
   def testGraphStackInLoopSwitchDtype(self):
-    with context.graph_mode(), self.test_session():
+    with self.cached_session():
       t1 = list_ops.empty_tensor_list(
           element_shape=constant_op.constant([], dtype=dtypes.int32),
           element_dtype=dtypes.int32)
@@ -476,6 +472,47 @@
           self.evaluate(t_full_zeros), np.zeros(
               (2,), dtype=dtype.as_numpy_dtype))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testZerosLikeVariant(self):
+    for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16,
+                  dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32,
+                  dtypes.float64, dtypes.complex64, dtypes.complex128,
+                  dtypes.bool):
+      l = list_ops.empty_tensor_list(
+          element_dtype=dtypes.variant, element_shape=scalar_shape())
+
+      sub_l = list_ops.empty_tensor_list(
+          element_dtype=dtype, element_shape=scalar_shape())
+      l = list_ops.tensor_list_push_back(l, sub_l)
+      sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+          1, dtype=dtype))
+      l = list_ops.tensor_list_push_back(l, sub_l)
+      sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast(
+          2, dtype=dtype))
+      l = list_ops.tensor_list_push_back(l, sub_l)
+
+      # l : [[],
+      #      [1],
+      #      [1, 2]]
+      #
+      # l_zeros : [[],
+      #            [0],
+      #            [0, 0]]
+      l_zeros = array_ops.zeros_like(l)
+
+      outputs = []
+      for _ in range(3):
+        l_zeros, out = list_ops.tensor_list_pop_back(
+            l_zeros, element_dtype=dtypes.variant)
+        outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype))
+
+      # Note: `outputs` contains popped values so the order is reversed.
+      self.assertAllEqual(self.evaluate(outputs[2]), [])
+      self.assertAllEqual(
+          self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype))
+      self.assertAllEqual(
+          self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index ee86cf0..baeb40d 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -42,7 +42,7 @@
         out = [compat.as_bytes(str(a)) for a in out]
       for diff_func in [array_ops.setdiff1d]:
         for index_dtype in [dtypes.int32, dtypes.int64]:
-          with self.test_session() as sess:
+          with self.cached_session() as sess:
             x_tensor = ops.convert_to_tensor(x, dtype=dtype)
             y_tensor = ops.convert_to_tensor(y, dtype=dtype)
             out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index e635a71..82729b9 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -31,7 +31,7 @@
 class LoggingOpsTest(test.TestCase):
 
   def testAssertDivideByZero(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       epsilon = ops.convert_to_tensor(1e-20)
       x = ops.convert_to_tensor(0.0)
       y = ops.convert_to_tensor(1.0)
@@ -66,7 +66,7 @@
     self.assertEqual(inp.get_shape(), inp_printed.get_shape())
 
   def testPrintGradient(self):
-    with self.test_session():
+    with self.cached_session():
       inp = constant_op.constant(2.0, shape=[100, 32], name="in")
       w = constant_op.constant(4.0, shape=[10, 100], name="w")
       wx = math_ops.matmul(w, inp, name="wx")
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 5f08339..38b14e3 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -36,7 +36,7 @@
 class HashTableOpTest(test.TestCase):
 
   def testHashTable(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -54,7 +54,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testHashTableFindHighRank(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -72,7 +72,7 @@
       self.assertAllEqual([[0, 1], [-1, -1]], result)
 
   def testHashTableInitWithPythonArrays(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = ["brain", "salad", "surgery"]
       values = [0, 1, 2]
@@ -90,7 +90,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testHashTableInitWithNumPyArrays(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
       values = np.array([0, 1, 2], dtype=np.int64)
@@ -107,7 +107,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testMultipleHashTables(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -135,7 +135,7 @@
       self.assertAllEqual([0, 1, -1], out3)
 
   def testHashTableWithTensorDefault(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = constant_op.constant(-1, dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@
       self.assertAllEqual([0, 1, -1], result)
 
   def testHashTableWithSparseTensorInput(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_val = constant_op.constant(-1, dtypes.int64)
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -173,7 +173,7 @@
       self.assertAllEqual(sp_shape, out_shape)
 
   def testSignatureMismatch(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -190,7 +190,7 @@
             lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
 
   def testDTypes(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       with self.assertRaises(TypeError):
         lookup_ops.HashTable(
@@ -198,7 +198,7 @@
                                                  dtypes.int64), default_val)
 
   def testNotInitialized(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       table = lookup_ops.HashTable(
           lookup_ops.KeyValueTensorInitializer(
@@ -211,7 +211,7 @@
         output.eval()
 
   def testInitializeTwice(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -223,7 +223,7 @@
         table.init.run()
 
   def testInitializationWithInvalidDimensions(self):
-    with self.test_session():
+    with self.cached_session():
       default_val = -1
       keys = constant_op.constant(["brain", "salad", "surgery"])
       values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -272,7 +272,7 @@
 
   def test_string_index_table_from_file(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -284,7 +284,7 @@
   def test_string_index_table_from_multicolumn_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file,
           num_oov_buckets=1,
@@ -299,7 +299,7 @@
   def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file,
           num_oov_buckets=1,
@@ -314,7 +314,7 @@
 
   def test_string_index_table_from_file_tensor_filename(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       vocabulary_file = constant_op.constant(vocabulary_file)
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -328,7 +328,7 @@
 
   def test_string_index_table_from_file_placeholder_filename(self):
     vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
-    with self.test_session():
+    with self.cached_session():
       vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -344,7 +344,7 @@
   def test_int32_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab2.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file,
           num_oov_buckets=1,
@@ -359,7 +359,7 @@
   def test_int64_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab3.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file,
           num_oov_buckets=1,
@@ -374,7 +374,7 @@
   def test_index_table_from_file_with_default_value(self):
     default_value = -42
     vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, default_value=default_value)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -385,7 +385,7 @@
 
   def test_index_table_from_file_with_oov_buckets(self):
     vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1000)
       ids = table.lookup(
@@ -432,7 +432,7 @@
 
   def test_index_table_from_file_with_vocab_size_too_small(self):
     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=2)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -444,7 +444,7 @@
 
   def test_index_table_from_file_with_vocab_size_too_large(self):
     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=4)
       self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -459,7 +459,7 @@
         vocabulary_file=vocabulary_file,
         vocab_size=0)
 
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=3)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -471,7 +471,7 @@
 
   def test_index_table_from_file_with_invalid_hashers(self):
     vocabulary_file = self._createVocabFile("invalid_hasher.txt")
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         lookup_ops.index_table_from_file(
             vocabulary_file=vocabulary_file,
@@ -490,14 +490,14 @@
 
   def test_index_table_from_file_table_ref_with_oov_buckets(self):
     vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=1)
       self.assertIsNotNone(table.table_ref)
 
   def test_index_table_from_file_table_ref_without_oov_buckets(self):
     vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_file(
           vocabulary_file=vocabulary_file, num_oov_buckets=0)
       self.assertIsNotNone(table.table_ref)
@@ -506,21 +506,21 @@
 class KeyValueTensorInitializerTest(test.TestCase):
 
   def test_string(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup_ops.KeyValueTensorInitializer(
           ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
       table = lookup_ops.HashTable(init, default_value=-1)
       table.init.run()
 
   def test_int64(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
                                                   dtypes.int64, dtypes.int64)
       table = lookup_ops.HashTable(init, default_value=-1)
       table.init.run()
 
   def test_int32(self):
-    with ops.Graph().as_default(), self.test_session():
+    with ops.Graph().as_default(), self.cached_session():
       init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
                                                   dtypes.int32, dtypes.int64)
       table = lookup_ops.HashTable(init, default_value=-1)
@@ -532,7 +532,7 @@
 class IndexTableFromTensor(test.TestCase):
 
   def test_index_table_from_tensor_with_tensor_init(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_tensor(
           vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
       ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
@@ -542,7 +542,7 @@
       self.assertAllEqual((1, 2, 3), ids.eval())
 
   def test_int32_index_table_from_tensor_with_tensor_init(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_tensor(
           vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
       ids = table.lookup(
@@ -553,7 +553,7 @@
       self.assertAllEqual((1, 2, 3), ids.eval())
 
   def test_int64_index_table_from_tensor_with_tensor_init(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_tensor(
           vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
       ids = table.lookup(
@@ -565,7 +565,7 @@
 
   def test_index_table_from_tensor_with_default_value(self):
     default_value = -42
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_tensor(
           vocabulary_list=["brain", "salad", "surgery"],
           default_value=default_value)
@@ -576,14 +576,14 @@
       self.assertAllEqual((1, 2, default_value), ids.eval())
 
   def test_index_table_from_tensor_missing_vocabulary_list(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError,
                                    "vocabulary_list must be specified"):
         lookup_ops.index_table_from_tensor(
             vocabulary_list=None, num_oov_buckets=1)
 
   def test_index_table_from_tensor_empty_vocabulary_list(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_table_from_tensor(
           vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
       ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -593,7 +593,7 @@
         lookup_ops.tables_initializer().run()
 
   def test_index_table_from_tensor_with_invalid_hashers(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         lookup_ops.index_table_from_tensor(
             vocabulary_list=["brain", "salad", "surgery"],
@@ -623,7 +623,7 @@
     type_funcs = [str, constant_op.constant]
     for type_func in type_funcs:
       vocabulary_file = type_func(vocabulary_path)
-      with self.test_session():
+      with self.cached_session():
         table = lookup_ops.index_to_string_table_from_file(
             vocabulary_file=vocabulary_file)
         features = table.lookup(
@@ -636,7 +636,7 @@
   def test_index_to_string_table_from_multicolumn_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file,
           key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -650,7 +650,7 @@
   def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file,
           key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -665,7 +665,7 @@
   def test_index_to_string_table_with_default_value(self):
     default_value = b"NONE"
     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, default_value=default_value)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -677,7 +677,7 @@
   def test_index_to_string_table_with_vocab_size_too_small(self):
     default_value = b"NONE"
     vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file,
           vocab_size=2,
@@ -690,7 +690,7 @@
 
   def test_index_to_string_table_with_vocab_size_too_large(self):
     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=4)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -702,7 +702,7 @@
 
   def test_index_to_string_table_with_vocab_size(self):
     vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.index_to_string_table_from_file(
           vocabulary_file=vocabulary_file, vocab_size=3)
       features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -715,7 +715,7 @@
 class IndexToStringTableFromTensorTest(test.TestCase):
 
   def test_index_to_string_table_from_tensor(self):
-    with self.test_session():
+    with self.cached_session():
       vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
       table = lookup_ops.index_to_string_table_from_tensor(
           vocabulary_list=vocabulary_list)
@@ -729,7 +729,7 @@
                           features.eval())
 
   def test_duplicate_entries(self):
-    with self.test_session():
+    with self.cached_session():
       vocabulary_list = constant_op.constant(["hello", "hello"])
       table = lookup_ops.index_to_string_table_from_tensor(
           vocabulary_list=vocabulary_list)
@@ -740,7 +740,7 @@
 
   def test_index_to_string_with_default_value(self):
     default_value = b"NONE"
-    with self.test_session():
+    with self.cached_session():
       vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
       table = lookup_ops.index_to_string_table_from_tensor(
           vocabulary_list=vocabulary_list, default_value=default_value)
@@ -764,7 +764,7 @@
   def testInitializeStringTable(self):
     vocabulary_file = self._createVocabFile("one_column_1.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       table = lookup_ops.HashTable(
           lookup_ops.TextFileInitializer(
@@ -782,7 +782,7 @@
     vocabulary_file = self._createVocabFile(
         "one_column_int64.txt", values=("42", "1", "-1000"))
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       table = lookup_ops.HashTable(
           lookup_ops.TextFileInitializer(
@@ -800,7 +800,7 @@
   def testInitializeIndexTable(self):
     vocabulary_file = self._createVocabFile("one_column_2.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       key_index = lookup_ops.TextFileIndex.LINE_NUMBER
       value_index = lookup_ops.TextFileIndex.WHOLE_LINE
@@ -821,7 +821,7 @@
     with open(vocabulary_file, "w") as f:
       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 1
       value_index = 2
@@ -843,7 +843,7 @@
     with open(vocabulary_file, "w") as f:
       f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 2
       value_index = 1
@@ -857,7 +857,7 @@
   def testInvalidDataType(self):
     vocabulary_file = self._createVocabFile("one_column_3.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       key_index = lookup_ops.TextFileIndex.WHOLE_LINE
       value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -870,7 +870,7 @@
 
   def testInvalidIndex(self):
     vocabulary_file = self._createVocabFile("one_column_4.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       key_index = 1  # second column of the line
       value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -885,7 +885,7 @@
   def testInitializeSameTableWithMultipleNodes(self):
     vocabulary_file = self._createVocabFile("one_column_5.txt")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       shared_name = "shared-one-columm"
       default_value = -1
       table1 = lookup_ops.HashTable(
@@ -924,7 +924,7 @@
       self.assertAllEqual([0, 1, -1], out3)
 
   def testInitializeTableWithNoFilename(self):
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       with self.assertRaises(ValueError):
         lookup_ops.HashTable(
@@ -934,7 +934,7 @@
             default_value)
 
   def testInitializeWithVocabSize(self):
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -982,7 +982,7 @@
   def testFeedVocabularyName(self):
     vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       table = lookup_ops.HashTable(
           lookup_ops.TextFileInitializer(
@@ -1008,7 +1008,7 @@
   def testInvalidFilenames(self):
     vocabulary_file = self._createVocabFile("filename_shape.txt")
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
 
       # Invalid data type
@@ -1031,7 +1031,7 @@
 
   def testIdToStringTable(self):
     vocab_file = self._createVocabFile("feat_to_id_1.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = "UNK"
       vocab_size = 3
       table = lookup_ops.HashTable(
@@ -1048,7 +1048,7 @@
 
   def testStringToIdTable(self):
     vocab_file = self._createVocabFile("feat_to_id_2.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       table = lookup_ops.HashTable(
@@ -1065,7 +1065,7 @@
   def testInt64ToIdTable(self):
     vocab_file = self._createVocabFile(
         "feat_to_id_3.txt", values=("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       table = lookup_ops.HashTable(
@@ -1090,7 +1090,7 @@
 
   def testStringIdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_1.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1110,7 +1110,7 @@
 
   def testInt32IdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1132,7 +1132,7 @@
 
   def testInt64IdTableWithHashBuckets(self):
     vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1151,7 +1151,7 @@
       self.assertEquals(vocab_size + oov_buckets, table.size().eval())
 
   def testStringIdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       oov_buckets = 5
 
       # Set a table that only uses hash buckets, for each input value returns
@@ -1172,7 +1172,7 @@
       self.assertEquals(oov_buckets, table.size().eval())
 
   def testInt32IdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       oov_buckets = 5
 
       # Set a table that only uses hash buckets, for each input value returns
@@ -1194,20 +1194,20 @@
       self.assertEquals(oov_buckets, table.size().eval())
 
   def testFloat64IdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
         lookup_ops.IdTableWithHashBuckets(
             None, num_oov_buckets=5, key_dtype=dtypes.float64)
 
   def testBoolIdTableWithOnlyHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
         lookup_ops.IdTableWithHashBuckets(
             None, num_oov_buckets=5, key_dtype=dtypes.bool)
 
   def testIdTableWithHashBucketsWithMultipleInitializers(self):
     vocab_file = self._createVocabFile("feat_to_id_4.txt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_value = -1
       vocab_size = 3
       oov_buckets = 3
@@ -1248,7 +1248,7 @@
   def testIdTableWithHashBucketsInitializationAcrossSessions(self):
     vocab_file = self._createVocabFile("feat_to_id_5.txt")
     shared_name = "across-sessions"
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1269,7 +1269,7 @@
       self.assertAllEqual([0, 1, 2, 3], out1.eval())
       self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
 
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1292,7 +1292,7 @@
 
   def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
     vocab_file = self._createVocabFile("feat_to_id_6.txt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       default_value1 = -1
       vocab_size = 3
       oov_buckets = 0
@@ -1328,7 +1328,7 @@
     vocab_file = self._createVocabFile("feat_to_id_7.txt")
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -1355,7 +1355,7 @@
   def testInt32SparseTensor(self):
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -1383,7 +1383,7 @@
   def testInt64SparseTensor(self):
     input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
     input_shape = [4, 4]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sp_features = sparse_tensor.SparseTensor(
           constant_op.constant(input_indices, dtypes.int64),
           constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -1410,7 +1410,7 @@
 
   def testIdTableWithHashBucketsWithInvalidHashers(self):
     vocab_file = self._createVocabFile("feat_to_id_4.txt")
-    with self.test_session():
+    with self.cached_session():
       default_value = -1
       vocab_size = 3
       oov_buckets = 1
@@ -1451,7 +1451,7 @@
             hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
 
   def testIdTableWithHashBucketsNoInnerTable(self):
-    with self.test_session():
+    with self.cached_session():
       table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
       self.assertIsNone(table.table_ref)
 
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 87fc715..3ce0b74 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -61,62 +61,62 @@
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.absolute_difference(
             self._predictions, self._predictions, weights=None)
 
   def testAllCorrectNoLossWeight(self):
     loss = losses.absolute_difference(self._predictions, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testNonZeroLoss(self):
     loss = losses.absolute_difference(self._labels, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5, loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithScalarTensorWeight(self):
     weights = 2.3
     loss = losses.absolute_difference(self._labels, self._predictions,
                                       constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeights(self):
     weights = constant_op.constant((1.2, 0.0), shape=(2, 1))
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.6, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(5.6, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeights(self):
     weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(16.6, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
     weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(6.0, loss.eval(), 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     weights = array_ops.zeros((2, 3))
     loss = losses.absolute_difference(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
@@ -134,12 +134,12 @@
     logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.softmax_cross_entropy(labels, logits, weights=None)
 
   def testAllCorrect(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
       labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
@@ -152,7 +152,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
 
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits)
       self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -162,7 +162,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
 
@@ -171,7 +171,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits,
                                           constant_op.constant(weights))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -181,7 +181,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
     weights = constant_op.constant((1.2, 3.4, 5.6))
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -190,7 +190,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
     weights = constant_op.constant([0, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
@@ -199,12 +199,12 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
     weights = constant_op.constant([1.2, 0, 0], shape=[3])
-    with self.test_session():
+    with self.cached_session():
       loss = losses.softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(12.0, loss.eval(), 3)
 
   def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -215,7 +215,7 @@
         losses.softmax_cross_entropy(labels, logits, weights=weights).eval()
 
   def testSoftmaxLabelSmoothing(self):
-    with self.test_session():
+    with self.cached_session():
       # Softmax Cross Entropy Loss is:
       #   -\sum_i p_i \log q_i
       # where for a softmax activation
@@ -242,12 +242,12 @@
     logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[0], [1], [2]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.sparse_softmax_cross_entropy(labels, logits, weights=None)
 
   def testAllCorrectInt32Labels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
       labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
@@ -263,7 +263,7 @@
     losses.sparse_softmax_cross_entropy(labels, logits)
 
   def testAllCorrectInt64Labels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
       labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64)
@@ -272,7 +272,7 @@
       self.assertAlmostEqual(loss.eval(), 0.0, 3)
 
   def testAllCorrectNonColumnLabels(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
                                      [0.0, 0.0, 10.0]])
       labels = constant_op.constant([0, 1, 2])
@@ -285,7 +285,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
 
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -295,7 +295,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
 
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -305,7 +305,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([2, 0, 1])
 
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits)
       self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
       self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -315,7 +315,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
 
@@ -324,7 +324,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits,
                                                  constant_op.constant(weights))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -334,7 +334,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = 2.3
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(
           labels, logits, constant_op.constant((weights,)))
       self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -345,7 +345,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = array_ops.placeholder(dtypes.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       loss_val = sess.run(loss,
                           feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
@@ -355,7 +355,7 @@
     logits = array_ops.placeholder(dtypes.float32)
     labels = array_ops.placeholder(dtypes.int32)
     weights = 1.0
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       loss_val = sess.run(loss,
                           feed_dict={
@@ -370,7 +370,7 @@
     logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
     labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
     weights = array_ops.placeholder(dtypes.float32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       loss_val = sess.run(loss,
                           feed_dict={
@@ -387,7 +387,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1))
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -396,7 +396,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([[1.2], [3.4], [5.6]])
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
 
@@ -405,7 +405,7 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([0, 0, 0], shape=(3, 1))
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
@@ -414,12 +414,12 @@
                                    [0.0, 0.0, 10.0]])
     labels = constant_op.constant([[2], [0], [1]])
     weights = constant_op.constant([1.2, 0, 0], shape=(3, 1))
-    with self.test_session():
+    with self.cached_session():
       loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
       self.assertAlmostEqual(12.0, loss.eval(), 3)
 
   def testMeasurementSpecificWeightsRaisesException(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -432,7 +432,7 @@
 
   def testInconsistentWeightSizeRaisesException(self):
     """The weight tensor has incorrect number of elements."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -445,7 +445,7 @@
 
   def testInconsistentLabelSizeRaisesException(self):
     """The label tensor has incorrect number of elements."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -458,7 +458,7 @@
 
   def testInconsistentWeightShapeRaisesException(self):
     """The weight tensor has incorrect shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0, -100.0],
                                      [-100.0, -100.0, 100.0, -100.0],
@@ -472,7 +472,7 @@
 
   def testInconsistentLabelShapeRaisesException(self):
     """The label tensor has incorrect shape."""
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0, -100.0],
                                      [-100.0, -100.0, 100.0, -100.0],
@@ -488,7 +488,7 @@
 class SigmoidCrossEntropyLossTest(test.TestCase):
 
   def testAllCorrectSigmoid(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -506,7 +506,7 @@
     loss = losses.sigmoid_cross_entropy(labels, logits, weights)
     self.assertEquals(logits.dtype, loss.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: np.ones((32, 1)),
@@ -522,7 +522,7 @@
     loss = losses.sigmoid_cross_entropy(labels, logits, weights)
     self.assertEquals(logits.dtype, loss.dtype)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss,
                       feed_dict={
                           logits: np.ones((32, 2)),
@@ -531,7 +531,7 @@
       self.assertAlmostEqual(0.313, loss, 3)
 
   def testAllWrongSigmoid(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -542,7 +542,7 @@
       self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
 
   def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0],
                                      [-100.0, 100.0, -100.0],
                                      [-100.0, -100.0, 100.0]])
@@ -562,7 +562,7 @@
     self.assertEquals(logits.dtype, loss.dtype)
     self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testSigmoidFloat64(self):
@@ -577,7 +577,7 @@
     loss = losses.sigmoid_cross_entropy(labels, logits)
     self.assertEquals(logits.dtype, loss.dtype)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(44.444, loss.eval(), 3)
 
   def testSigmoidNoReduction(self):
@@ -590,7 +590,7 @@
         labels, logits, reduction=losses.Reduction.NONE)
     self.assertEquals(logits.dtype, loss.dtype)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose((
           (0., 0., 0.),
           (0., 100., 100.),
@@ -598,7 +598,7 @@
       ), loss.eval(), 3)
 
   def testSigmoidLabelSmoothingCorrect(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[100.0, -100.0, -100.0]])
       labels = constant_op.constant([[1, 0, 1]])
       # Sigmoid cross entropy loss is:
@@ -621,7 +621,7 @@
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
 
   def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
-    with self.test_session():
+    with self.cached_session():
       label_smoothing = 0.1
       sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
       sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -656,33 +656,33 @@
     self._labels = constant_op.constant(labels)
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.log_loss(self._labels, self._labels, weights=None)
 
   def testAllCorrectNoLossWeight(self):
     loss = losses.log_loss(self._labels, self._labels)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testAllCorrectNoLossWeightWithPlaceholder(self):
     tf_predictions = array_ops.placeholder(
         dtypes.float32, shape=self._np_labels.shape)
     loss = losses.log_loss(self._labels, tf_predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(
           0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
 
   def testNonZeroLoss(self):
     loss = losses.log_loss(self._labels, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = losses.log_loss(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
@@ -690,7 +690,7 @@
     weights = 2.3
     loss = losses.log_loss(self._labels, self._predictions,
                            constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss.eval(), 3)
 
@@ -700,7 +700,7 @@
     weights = 2.3
     loss = losses.log_loss(self._labels, tf_predictions,
                            constant_op.constant(weights))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss, 3)
@@ -710,7 +710,7 @@
     weights = 2.3
     loss = losses.log_loss(self._labels, tf_predictions,
                            constant_op.constant(weights))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
                              loss, 3)
@@ -721,7 +721,7 @@
         self._expected_losses,
         np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
     loss = losses.log_loss(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -730,7 +730,7 @@
                                   np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                       (2, 3)))
     loss = losses.log_loss(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -739,12 +739,12 @@
                                   np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
                                       (2, 3)))
     loss = losses.log_loss(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
 
   def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
     weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.log_loss(self._labels, self._predictions, weights)
 
@@ -757,7 +757,7 @@
         self._predictions,
         constant_op.constant(
             weights, shape=(2, 3)))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
 
   def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -771,7 +771,7 @@
         constant_op.constant(
             weights, shape=(2, 3)))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
 
@@ -784,7 +784,7 @@
         self._predictions,
         constant_op.constant(
             weights, shape=(2, 3)))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -795,35 +795,35 @@
     tf_weights = constant_op.constant(weights, shape=(2, 3))
     loss = losses.log_loss(self._labels, tf_predictions, tf_weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
       self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     tf_weights = array_ops.zeros(shape=(2, 3))
     loss = losses.log_loss(self._labels, self._predictions, tf_weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
 class HingeLossTest(test.TestCase):
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[-1.0], [2.1]])
       labels = constant_op.constant([0.0, 1.0])
       with self.assertRaises(ValueError):
         _ = losses.hinge_loss(labels, logits).eval()
 
   def testAllOutsideMargin(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
       labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
       loss = losses.hinge_loss(labels, logits)
       self.assertAllClose(loss.eval(), 0.0, atol=1e-3)
 
   def testSomeInsideMargin(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
       labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
       loss = losses.hinge_loss(labels, logits)
@@ -832,7 +832,7 @@
       self.assertAllClose(loss.eval(), 0.175, atol=1e-3)
 
   def testSomeMisclassified(self):
-    with self.test_session():
+    with self.cached_session():
       logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
       labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
       loss = losses.hinge_loss(labels, logits)
@@ -844,14 +844,14 @@
 class HuberLossTest(test.TestCase):
 
   def testIncompatibleShapes(self):
-    with self.test_session():
+    with self.cached_session():
       predictions = constant_op.constant([[-1.0], [2.1]])
       labels = constant_op.constant([0.0, 1.0])
       with self.assertRaises(ValueError):
         _ = losses.huber_loss(labels, predictions).eval()
 
   def testAllQuadratic(self):
-    with self.test_session():
+    with self.cached_session():
       predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
       labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
       loss = losses.huber_loss(labels, predictions)
@@ -859,7 +859,7 @@
                           0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4., atol=1e-5)
 
   def testAllLinear(self):
-    with self.test_session():
+    with self.cached_session():
       predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
       labels = constant_op.constant([0.0, 1.0, 0.0, 1.5])
       loss = losses.huber_loss(labels, predictions)
@@ -867,7 +867,7 @@
                           (1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5, atol=1e-5)
 
   def testMixedQuadraticLinear(self):
-    with self.test_session():
+    with self.cached_session():
       predictions = constant_op.constant([[1.5, -1.4, -1.0, 0.0],
                                           [1.5, -1.4, -1.0, 0.0]])
       labels = constant_op.constant([[1.0, -1.0, 0.0, 0.5],
@@ -879,7 +879,7 @@
       self.assertAllClose(loss.eval(), expected_loss, atol=1e-5)
 
   def testAllQuadraticDelta(self):
-    with self.test_session():
+    with self.cached_session():
       delta = 0.5
       predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0])
       labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
@@ -894,7 +894,7 @@
     expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean()
     expected -= 0.5 * delta**2
     loss = losses.huber_loss(labels, predictions, delta=delta)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(expected, loss.eval(), atol=1e-5)
 
 
@@ -906,13 +906,13 @@
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.mean_squared_error(
             self._predictions, self._predictions, weights=None)
 
   def testScalar(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(
           0.0,
           losses.mean_squared_error(predictions=constant_op.constant(0),
@@ -920,55 +920,55 @@
 
   def testAllCorrectNoLossWeight(self):
     loss = losses.mean_squared_error(self._predictions, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testNonZeroLoss(self):
     loss = losses.mean_squared_error(self._labels, self._predictions)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5, loss.eval(), 3)
 
   def testNonZeroLossWithPythonScalarWeight(self):
     weights = 2.3
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithScalarTensorWeight(self):
     weights = 2.3
     loss = losses.mean_squared_error(self._labels, self._predictions,
                                      constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
 
   def testNonZeroLossWithOneDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 3.4], shape=(2, 1))
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
     weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeights(self):
     weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
 
   def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
     weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(18.0, loss.eval(), 3)
 
   def testLossWithSampleSpecificWeightsAllZero(self):
     weights = array_ops.zeros((2, 3))
     loss = losses.mean_squared_error(self._labels, self._predictions, weights)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
@@ -994,7 +994,7 @@
     self._expected_losses = np.divide(total, 3.0)
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.mean_pairwise_squared_error(
             predictions=constant_op.constant(self._labels),
@@ -1003,7 +1003,7 @@
 
   def _test_valid_weights(
       self, labels, predictions, expected_loss, weights=1.0):
-    with self.test_session():
+    with self.cached_session():
       static_inputs_op = losses.mean_pairwise_squared_error(
           predictions=predictions, labels=labels, weights=weights)
       self.assertAlmostEqual(expected_loss, static_inputs_op.eval(), places=3)
@@ -1054,7 +1054,7 @@
 
       init_op = variables.global_variables_initializer()
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(init_op)
         for grad, _ in gradients_to_variables:
           np_grad = sess.run(grad)
@@ -1073,7 +1073,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         weights=constant_op.constant(weights))
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(weights * np.sum(self._expected_losses),
                              loss.eval(), 3)
 
@@ -1122,7 +1122,7 @@
         predictions=predictions_placeholder,
         labels=labels_placeholder,
         weights=weights_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
         dynamic_inputs_op.eval(feed_dict={
             predictions_placeholder: predictions,
@@ -1191,7 +1191,7 @@
           labels=array_ops.concat([labels0, labels1], 0),
           predictions=array_ops.concat([predictions0, predictions1], 0))
 
-      with self.test_session() as session:
+      with self.cached_session() as session:
         loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
 
         self.assertTrue(loss0 > 0)
@@ -1216,7 +1216,7 @@
                                [0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3))
 
   def testValueErrorThrownWhenWeightIsNone(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         losses.cosine_distance(
             predictions=constant_op.constant(self._labels),
@@ -1229,7 +1229,7 @@
         predictions=constant_op.constant(self._labels),
         labels=constant_op.constant(self._labels),
         dim=2)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(0, loss.eval(), 5)
 
   def testPartiallyCorrectWithIntegerValues(self):
@@ -1237,7 +1237,7 @@
         predictions=constant_op.constant(self._predictions),
         labels=constant_op.constant(self._labels),
         dim=2)
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(1, loss.eval(), 5)
 
   def testPartiallyCorrectFloatingPointValues(self):
@@ -1255,7 +1255,7 @@
         labels, shape=(3, 1, 3), dtype=dtypes.float32)
     loss = losses.cosine_distance(tf_labels, tf_preds, dim=2)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAlmostEqual(1.0, loss.eval(), 5)
 
   def testSampleSpecificWeights(self):
@@ -1264,7 +1264,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=np.asarray((1, 0, 0)).reshape((3, 1, 1)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(1.0, loss.eval())
 
   def testMeasurementSpecificWeights(self):
@@ -1274,7 +1274,7 @@
         dim=2,
         weights=constant_op.constant(
             [1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(3.0 / 4.0, loss.eval())
 
   def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
@@ -1286,7 +1286,7 @@
         dim=2,
         weights=constant_op.constant(
             [1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
       self.assertEqual(3.0 / 4.0, loss)
 
@@ -1296,7 +1296,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=array_ops.zeros((3, 1, 1)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0, loss.eval())
 
   def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1305,7 +1305,7 @@
         labels=constant_op.constant(self._labels),
         dim=2,
         weights=array_ops.zeros((3, 2, 1)))
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(0, loss.eval())
 
 
@@ -1411,7 +1411,7 @@
       weighted_loss = losses.compute_weighted_loss(
           self._raw_losses, weights=weight)
       self.assertEqual(1, len(util.get_losses()))
-      with self.test_session():
+      with self.cached_session():
         self.assertAllClose(
             np.mean(weight * self._raw_losses), weighted_loss.eval())
 
@@ -1429,7 +1429,7 @@
       weighted_loss = losses.compute_weighted_loss(
           self._raw_losses, weights=weights_placeholder)
       self.assertEqual(1, len(util.get_losses()))
-      with self.test_session():
+      with self.cached_session():
         with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
           weighted_loss.eval(feed_dict={weights_placeholder: weights})
 
@@ -1452,7 +1452,7 @@
       weighted_loss = losses.compute_weighted_loss(
           raw_losses, weights=weights_placeholder)
       self.assertEqual(1, len(util.get_losses()))
-      with self.test_session():
+      with self.cached_session():
         with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
           weighted_loss.eval(feed_dict={weights_placeholder: weights})
 
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index dc3ea38..f71857a 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -42,12 +42,12 @@
 
   def _testRoll(self, np_input, shift, axis):
     expected_roll = np.roll(np_input, shift, axis)
-    with self.test_session():
+    with self.cached_session():
       roll = manip_ops.roll(np_input, shift, axis)
       self.assertAllEqual(roll.eval(), expected_roll)
 
   def _testGradient(self, np_input, shift, axis):
-    with self.test_session():
+    with self.cached_session():
       inx = constant_op.constant(np_input.tolist())
       xs = list(np_input.shape)
       y = manip_ops.roll(inx, shift, axis)
@@ -94,7 +94,7 @@
     self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
     self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
     # Make sure negative axis should be 0 <= axis + dims < dims
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "is out of range"):
         manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
@@ -111,7 +111,7 @@
     tensor = array_ops.placeholder(dtype=dtypes.int32)
     shift = 1
     axis = 0
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "input must be 1-D or higher"):
         manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@@ -127,7 +127,7 @@
     tensor = [[1, 2], [3, 4]]
     shift = 1
     axis = array_ops.placeholder(dtype=dtypes.int32)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "axis must be a scalar or a 1-D vector"):
         manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@@ -143,7 +143,7 @@
     tensor = [[1, 2], [3, 4]]
     shift = array_ops.placeholder(dtype=dtypes.int32)
     axis = 1
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "shift must be a scalar or a 1-D vector"):
         manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@@ -158,7 +158,7 @@
     tensor = [[1, 2], [3, 4]]
     shift = array_ops.placeholder(dtype=dtypes.int32)
     axis = [0, 1]
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "shift and axis must have the same size"):
         manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@@ -167,7 +167,7 @@
     tensor = [1, 2]
     shift = 1
     axis = 1
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "is out of range"):
         manip_ops.roll(tensor, shift, axis).eval()
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index b167278..309da8f 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -206,7 +206,7 @@
     b = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0], [80.0, 90.0]])
     c = infix_matmul(a, b)
     d = math_ops.matmul(a, b)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(c.eval(), d.eval())
 
 
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index f41967f..720ba80 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -114,7 +114,7 @@
 
   def testNotInvertible(self):
     # The input should be invertible.
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Input is not invertible."):
         # All rows of the matrix below add to zero.
         tensor3 = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 3328839..dd01ba1 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -143,7 +143,7 @@
   def testNonSquareMatrix(self):
     # A non-square matrix should cause an error.
     matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         self._verifySolve(matrix, matrix)
       with self.assertRaises(ValueError):
@@ -154,7 +154,7 @@
     # right-hand sides.
     matrix = np.array([[1., 0.], [0., 1.]])
     rhs = np.array([[1., 0.]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         self._verifySolve(matrix, rhs)
       with self.assertRaises(ValueError):
@@ -164,7 +164,7 @@
     # The input should be invertible.
     # The matrix is singular because it has a zero on the diagonal.
     singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Input matrix is not invertible."):
         self._verifySolve(singular_matrix, singular_matrix)
       with self.assertRaisesOpError("Input matrix is not invertible."):
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 5565348..5dcdb9e 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -192,7 +192,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -209,7 +209,7 @@
       self.assertAlmostEqual(1.65, sess.run(mean), 5)
 
   def testUpdateOpsReturnsCurrentValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -253,7 +253,7 @@
         metrics.mean(values, weights=np.ones((3, 2, 4, 1))),
         metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),)
     expected = np.mean(values)
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       for mean_result in mean_results:
         mean, update_op = mean_result
@@ -266,7 +266,7 @@
         np.sum(np.multiply(weights, np.ones_like(values)))
     )
     mean, update_op = metrics.mean(values, weights=weights)
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       self.assertAlmostEqual(expected, update_op.eval(), places=5)
       self.assertAlmostEqual(expected, mean.eval(), places=5)
@@ -330,7 +330,7 @@
 
       # Dynamic shapes.
       with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
-        with self.test_session():
+        with self.cached_session():
           _, update_op = metrics.mean(values_placeholder, invalid_weight)
           variables.local_variables_initializer().run()
           update_op.eval(feed_dict={values_placeholder: values})
@@ -359,7 +359,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -376,7 +376,7 @@
       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
 
   def testMultiDimensional(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
       _enqueue_vector(
@@ -397,7 +397,7 @@
       self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
 
   def testUpdateOpsReturnsCurrentValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
       _enqueue_vector(sess, values_queue, [0, 1])
@@ -418,7 +418,7 @@
       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
 
   def testBinaryWeighted1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -445,7 +445,7 @@
       self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
 
   def testWeighted1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -472,7 +472,7 @@
       self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5)
 
   def testWeighted2d_1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -499,7 +499,7 @@
       self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
 
   def testWeighted2d_2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -575,7 +575,7 @@
         (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
     accuracy, update_op = metrics.accuracy(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -588,7 +588,7 @@
         self.assertEqual(initial_accuracy, accuracy.eval())
 
   def testMultipleUpdates(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -618,7 +618,7 @@
   def testEffectivelyEquivalentSizes(self):
     predictions = array_ops.ones((40, 1))
     labels = array_ops.ones((40,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.accuracy(labels, predictions)
 
       sess.run(variables.local_variables_initializer())
@@ -628,7 +628,7 @@
   def testEffectivelyEquivalentSizesWithScalarWeight(self):
     predictions = array_ops.ones((40, 1))
     labels = array_ops.ones((40,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)
 
       sess.run(variables.local_variables_initializer())
@@ -642,7 +642,7 @@
     weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
                                     1)  # shape 3, 1
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.accuracy(labels, predictions, weights)
 
       sess.run(variables.local_variables_initializer())
@@ -662,7 +662,7 @@
         dtype=dtypes_lib.int32, name='weights')
     feed_dict = {weights_placeholder: weights}
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       accuracy, update_op = metrics.accuracy(labels, predictions,
                                              weights_placeholder)
 
@@ -674,7 +674,7 @@
       self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
 
   def testMultipleUpdatesWithWeightedValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -746,7 +746,7 @@
         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
     precision, update_op = metrics.precision(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -765,7 +765,7 @@
     labels = constant_op.constant(inputs)
     precision, update_op = metrics.precision(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1, sess.run(update_op))
       self.assertAlmostEqual(1, precision.eval())
@@ -778,7 +778,7 @@
           constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
       precision, update_op = metrics.precision(labels, predictions)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertAlmostEqual(0.5, update_op.eval())
         self.assertAlmostEqual(0.5, precision.eval())
@@ -789,7 +789,7 @@
     precision, update_op = metrics.precision(
         labels, predictions, weights=constant_op.constant([[2], [5]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 2.0 + 5.0
       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -806,7 +806,7 @@
     }
     precision, update_op = metrics.precision(labels, predictions, weights=2)
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 2.0 + 2.0
       weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
@@ -826,7 +826,7 @@
     precision, update_op = metrics.precision(
         labels, predictions, weights=constant_op.constant([[2], [5]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 2.0 + 5.0
       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -844,7 +844,7 @@
         predictions,
         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 3.0 + 4.0
       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -864,7 +864,7 @@
         predictions,
         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
 
-    with self.test_session():
+    with self.cached_session():
       variables.local_variables_initializer().run()
       weighted_tp = 3.0 + 4.0
       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -881,7 +881,7 @@
     labels = constant_op.constant(1 - inputs)
     precision, update_op = metrics.precision(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertAlmostEqual(0, precision.eval())
@@ -891,7 +891,7 @@
     labels = constant_op.constant([0, 0, 0, 0])
     precision, update_op = metrics.precision(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0.0, precision.eval())
@@ -933,7 +933,7 @@
         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
     recall, update_op = metrics.recall(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -952,7 +952,7 @@
     labels = constant_op.constant(np_inputs)
     recall, update_op = metrics.recall(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(1, recall.eval())
@@ -965,7 +965,7 @@
           constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
       recall, update_op = metrics.recall(labels, predictions)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertAlmostEqual(0.5, update_op.eval())
         self.assertAlmostEqual(0.5, recall.eval())
@@ -976,7 +976,7 @@
     weights = constant_op.constant([[2], [5]])
     recall, update_op = metrics.recall(labels, predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_tp = 2.0 + 5.0
       weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -990,7 +990,7 @@
     weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
     recall, update_op = metrics.recall(labels, predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       weighted_tp = 3.0 + 1.0
       weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1005,7 +1005,7 @@
     labels = constant_op.constant(1 - np_inputs)
     recall, update_op = metrics.recall(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, recall.eval())
@@ -1015,7 +1015,7 @@
     labels = array_ops.zeros((1, 4))
     recall, update_op = metrics.recall(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       sess.run(update_op)
       self.assertEqual(0, recall.eval())
@@ -1055,7 +1055,7 @@
         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
     auc, update_op = metrics.auc(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1073,7 +1073,7 @@
   def allCorrectAsExpected(self, curve):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       auc, update_op = metrics.auc(labels, predictions, curve=curve)
@@ -1084,7 +1084,7 @@
       self.assertEqual(1, auc.eval())
 
   def testSomeCorrect_multipleLabelDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for label_dtype in (
           dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
         predictions = constant_op.constant(
@@ -1099,7 +1099,7 @@
         self.assertAlmostEqual(0.5, auc.eval())
 
   def testWeighted1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1112,7 +1112,7 @@
       self.assertAlmostEqual(0.5, auc.eval(), 5)
 
   def testWeighted2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1127,7 +1127,7 @@
   # Regarding the AUC-PR tests: note that the preferred method when
   # calculating AUC-PR is summation_method='careful_interpolation'.
   def testCorrectAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1141,7 +1141,7 @@
       self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
 
   def testCorrectAnotherAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
           shape=(1, 7),
@@ -1157,7 +1157,7 @@
       self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
 
   def testThirdCorrectAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
           shape=(1, 7),
@@ -1173,7 +1173,7 @@
       self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
 
   def testIncorrectAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1186,7 +1186,7 @@
       self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
 
   def testAnotherIncorrectAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
           shape=(1, 7),
@@ -1201,7 +1201,7 @@
       self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
 
   def testThirdIncorrectAUCPRSpecialCase(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
           shape=(1, 7),
@@ -1218,7 +1218,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       auc, update_op = metrics.auc(labels, predictions)
@@ -1229,7 +1229,7 @@
       self.assertAlmostEqual(0, auc.eval())
 
   def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       auc, update_op = metrics.auc(labels, predictions)
@@ -1240,7 +1240,7 @@
       self.assertAlmostEqual(1, auc.eval(), 6)
 
   def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
       labels = array_ops.ones([4])
       auc, update_op = metrics.auc(labels, predictions, curve='PR')
@@ -1301,7 +1301,7 @@
         scale=1.0, size=num_samples)):
       expected_auc = self.np_auc(predictions, labels, weights)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         enqueue_ops = [[] for i in range(num_batches)]
         tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
         tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1370,7 +1370,7 @@
     specificity, update_op = metrics.specificity_at_sensitivity(
         labels, predictions, sensitivity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1390,7 +1390,7 @@
     specificity, update_op = metrics.specificity_at_sensitivity(
         labels, predictions, sensitivity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, specificity.eval())
@@ -1405,7 +1405,7 @@
     specificity, update_op = metrics.specificity_at_sensitivity(
         labels, predictions, sensitivity=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1.0, sess.run(update_op))
       self.assertAlmostEqual(1.0, specificity.eval())
@@ -1420,7 +1420,7 @@
     specificity, update_op = metrics.specificity_at_sensitivity(
         labels, predictions, sensitivity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1439,7 +1439,7 @@
       specificity, update_op = metrics.specificity_at_sensitivity(
           labels, predictions, weights=weights, sensitivity=0.4)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
 
         self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1457,7 +1457,7 @@
     specificity, update_op = metrics.specificity_at_sensitivity(
         labels, predictions, weights=weights, sensitivity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -1507,7 +1507,7 @@
     sensitivity, update_op = metrics.sensitivity_at_specificity(
         labels, predictions, specificity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -1527,7 +1527,7 @@
     specificity, update_op = metrics.sensitivity_at_specificity(
         labels, predictions, specificity=0.7)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1, sess.run(update_op))
       self.assertEqual(1, specificity.eval())
@@ -1542,7 +1542,7 @@
     specificity, update_op = metrics.sensitivity_at_specificity(
         labels, predictions, specificity=0.8)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.8, sess.run(update_op))
       self.assertAlmostEqual(0.8, specificity.eval())
@@ -1557,7 +1557,7 @@
     specificity, update_op = metrics.sensitivity_at_specificity(
         labels, predictions, specificity=0.4)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(0.6, sess.run(update_op))
       self.assertAlmostEqual(0.6, specificity.eval())
@@ -1576,7 +1576,7 @@
       specificity, update_op = metrics.sensitivity_at_specificity(
           labels, predictions, weights=weights, specificity=0.4)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(variables.local_variables_initializer())
         self.assertAlmostEqual(0.675, sess.run(update_op))
         self.assertAlmostEqual(0.675, specificity.eval())
@@ -1638,7 +1638,7 @@
                                                     thresholds)
     rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates, then verify idempotency.
@@ -1654,7 +1654,7 @@
   def testAllCorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(inputs)
       thresholds = [0.5]
@@ -1670,7 +1670,7 @@
       self.assertEqual(1, rec.eval())
 
   def testSomeCorrect_multipleLabelDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for label_dtype in (
           dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
         predictions = constant_op.constant(
@@ -1692,7 +1692,7 @@
   def testAllIncorrect(self):
     inputs = np.random.randint(0, 2, size=(100, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
       thresholds = [0.5]
@@ -1708,7 +1708,7 @@
       self.assertAlmostEqual(0, rec.eval())
 
   def testWeights1d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1738,7 +1738,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
 
   def testWeights2d(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1768,7 +1768,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
 
   def testExtremeThresholds(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -1792,7 +1792,7 @@
       self.assertAlmostEqual(0.0, rec_high.eval())
 
   def testZeroLabelsPredictions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
       labels = array_ops.zeros([4])
       thresholds = [0.5]
@@ -1842,7 +1842,7 @@
     labels = labels.astype(np.float32)
     predictions = predictions.astype(np.float32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Reshape the data so its easy to queue up:
       predictions_batches = predictions.reshape((batch_size, num_batches))
       labels_batches = labels.reshape((batch_size, num_batches))
@@ -2801,7 +2801,7 @@
     labels = random_ops.random_normal((10, 3), seed=2)
     error, update_op = metrics.mean_absolute_error(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2822,7 +2822,7 @@
 
     error, update_op = metrics.mean_absolute_error(labels, predictions, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(3, sess.run(update_op))
       self.assertEqual(3, error.eval())
@@ -2866,7 +2866,7 @@
     error, update_op = metrics.mean_relative_error(labels, predictions,
                                                    normalizer)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2891,7 +2891,7 @@
     error, update_op = metrics.mean_relative_error(
         labels, predictions, normalizer=labels)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(expected_error, sess.run(update_op))
       self.assertEqual(expected_error, error.eval())
@@ -2907,7 +2907,7 @@
     error, update_op = metrics.mean_relative_error(
         labels, predictions, normalizer=array_ops.zeros_like(labels))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0.0, sess.run(update_op))
       self.assertEqual(0.0, error.eval())
@@ -2945,7 +2945,7 @@
     labels = random_ops.random_normal((10, 3), seed=2)
     error, update_op = metrics.mean_squared_error(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -2963,7 +2963,7 @@
 
     error, update_op = metrics.mean_squared_error(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -2976,7 +2976,7 @@
 
     error, update_op = metrics.mean_squared_error(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(6, sess.run(update_op))
       self.assertEqual(6, error.eval())
@@ -2990,13 +2990,13 @@
 
     error, update_op = metrics.mean_squared_error(labels, predictions, weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(13, sess.run(update_op))
       self.assertEqual(13, error.eval())
 
   def testMultipleBatchesOfSizeOne(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3020,7 +3020,7 @@
       self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
 
   def testMetricsComputedConcurrently(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates one set of predictions.
       preds_queue0 = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3063,7 +3063,7 @@
       self.assertAlmostEqual(79.0 / 6, mse1, 5)
 
   def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3122,7 +3122,7 @@
     labels = random_ops.random_normal((10, 3), seed=2)
     error, update_op = metrics.root_mean_squared_error(labels, predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3135,7 +3135,7 @@
         self.assertEqual(initial_error, error.eval())
 
   def testSingleUpdateZeroError(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           0.0, shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -3148,7 +3148,7 @@
       self.assertEqual(0, rmse.eval())
 
   def testSingleUpdateWithError(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -3161,7 +3161,7 @@
       self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
 
   def testSingleUpdateWithErrorAndWeights(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       predictions = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
       labels = constant_op.constant(
@@ -3220,7 +3220,7 @@
     labels = random_ops.random_normal((10, 3), seed=2)
     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3242,7 +3242,7 @@
 
     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -3258,7 +3258,7 @@
 
     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1, sess.run(update_op), 5)
       self.assertAlmostEqual(1, error.eval(), 5)
@@ -3279,7 +3279,7 @@
         np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAlmostEqual(1.0, sess.run(update_op), 5)
       self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -3298,7 +3298,7 @@
     error, update_op = metrics.mean_cosine_distance(
         labels, predictions, dim=2, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(0, sess.run(update_op))
       self.assertEqual(0, error.eval())
@@ -3317,7 +3317,7 @@
     error, update_op = metrics.mean_cosine_distance(
         labels, predictions, dim=2, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertEqual(1.5, update_op.eval())
       self.assertEqual(1.5, error.eval())
@@ -3352,7 +3352,7 @@
     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
 
   def testOneUpdate(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
 
@@ -3369,7 +3369,7 @@
       self.assertAlmostEqual(0.0, pcnt2, 5)
 
   def testSomePresentOneUpdate(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values = constant_op.constant(
           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
       weights = constant_op.constant(
@@ -3445,7 +3445,7 @@
     mean_iou, update_op = metrics.mean_iou(
         labels, predictions, num_classes=num_classes)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3459,7 +3459,7 @@
 
   def testMultipleUpdates(self):
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3490,7 +3490,7 @@
 
   def testMultipleUpdatesWithWeights(self):
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3538,7 +3538,7 @@
     # one class, and thus there is one row and one column with
     # zero entries in the confusion matrix.
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       # There is no prediction for class 2.
       preds_queue = data_flow_ops.FIFOQueue(
@@ -3585,7 +3585,7 @@
         ],
         0)
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       confusion_matrix = update_op.eval()
@@ -3597,7 +3597,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.zeros([40])
     num_classes = 1
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       self.assertEqual(40, update_op.eval()[0])
@@ -3607,7 +3607,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.ones([40])
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
@@ -3637,7 +3637,7 @@
                         0, shape=[1])
         ],
         0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(
           labels, predictions, num_classes, weights=weights)
       sess.run(variables.local_variables_initializer())
@@ -3657,7 +3657,7 @@
         [[0, 0, 2, 1, 1, 1],
          [1, 1, 2, 0, 0, 0]]])
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
@@ -3669,7 +3669,7 @@
     labels = constant_op.constant([0])
     predictions = constant_op.constant([0])
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
@@ -3687,7 +3687,7 @@
         [[0, 0, 0, 1, 1, 1],
          [1, 1, 1, 0, 0, 0]]])
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
@@ -3751,7 +3751,7 @@
     mean_accuracy, update_op = metrics.mean_per_class_accuracy(
         labels, predictions, num_classes=num_classes)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
 
       # Run several updates.
@@ -3764,7 +3764,7 @@
         self.assertEqual(initial_mean_accuracy, mean_accuracy.eval())
 
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3796,7 +3796,7 @@
 
   def testMultipleUpdatesWithWeights(self):
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       preds_queue = data_flow_ops.FIFOQueue(
           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3844,7 +3844,7 @@
     # one class, and thus there is one row and one column with
     # zero entries in the confusion matrix.
     num_classes = 3
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Create the queue that populates the predictions.
       # There is no prediction for class 2.
       preds_queue = data_flow_ops.FIFOQueue(
@@ -3880,7 +3880,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.zeros([40])
     num_classes = 1
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
           labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
@@ -3891,7 +3891,7 @@
     predictions = array_ops.zeros([40])
     labels = array_ops.ones([40])
     num_classes = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
           labels, predictions, num_classes)
       sess.run(variables.local_variables_initializer())
@@ -3910,7 +3910,7 @@
         constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
         constant_op.constant(0, shape=[1])
     ], 0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
           labels, predictions, num_classes, weights=weights)
       sess.run(variables.local_variables_initializer())
@@ -3944,7 +3944,7 @@
     tn, tn_update_op = metrics.false_negatives(
         labels=labels, predictions=predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(3., tn_update_op.eval())
@@ -3963,7 +3963,7 @@
     tn, tn_update_op = metrics.false_negatives(
         labels=labels, predictions=predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(5., tn_update_op.eval())
@@ -3993,7 +3993,7 @@
     fn, fn_update_op = metrics.false_negatives_at_thresholds(
         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), fn.eval())
       self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -4012,7 +4012,7 @@
         weights=((3.0,), (5.0,), (7.0,)),
         thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
       self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -4043,7 +4043,7 @@
     tn, tn_update_op = metrics.false_positives(
         labels=labels, predictions=predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(7., tn_update_op.eval())
@@ -4062,7 +4062,7 @@
     tn, tn_update_op = metrics.false_positives(
         labels=labels, predictions=predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(14., tn_update_op.eval())
@@ -4092,7 +4092,7 @@
     fp, fp_update_op = metrics.false_positives_at_thresholds(
         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), fp.eval())
       self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -4113,7 +4113,7 @@
                  (19.0, 23.0, 29.0, 31.0)),
         thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
       self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -4144,7 +4144,7 @@
     tn, tn_update_op = metrics.true_negatives(
         labels=labels, predictions=predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(3., tn_update_op.eval())
@@ -4163,7 +4163,7 @@
     tn, tn_update_op = metrics.true_negatives(
         labels=labels, predictions=predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(4., tn_update_op.eval())
@@ -4193,7 +4193,7 @@
     tn, tn_update_op = metrics.true_negatives_at_thresholds(
         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), tn.eval())
       self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -4212,7 +4212,7 @@
         weights=((0.0, 2.0, 3.0, 5.0),),
         thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
       self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -4243,7 +4243,7 @@
     tn, tn_update_op = metrics.true_positives(
         labels=labels, predictions=predictions)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(7., tn_update_op.eval())
@@ -4262,7 +4262,7 @@
     tn, tn_update_op = metrics.true_positives(
         labels=labels, predictions=predictions, weights=weights)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllClose(0., tn.eval())
       self.assertAllClose(12., tn_update_op.eval())
@@ -4292,7 +4292,7 @@
     tp, tp_update_op = metrics.true_positives_at_thresholds(
         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0, 0, 0), tp.eval())
       self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -4309,7 +4309,7 @@
         predictions=predictions, labels=labels, weights=37.0,
         thresholds=[0.15, 0.5, 0.85])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(variables.local_variables_initializer())
       self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
       self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index 944de21..e415d78 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -188,7 +188,7 @@
                       mode="SYMMETRIC").eval()
 
   def testInvalid(self):
-    with self.test_session():
+    with self.cached_session():
       x = [[1, 2, 3], [4, 5, 6]]
       with self.assertRaisesRegexp(ValueError, "Unknown padding mode"):
         array_ops.pad(x, [[1, 0], [2, 1]], mode="weird").eval()
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index d8c3f982..95f3dcc 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -95,13 +95,13 @@
       """, q.queue_ref.op.node_def)
 
   def testEnqueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       enqueue_op = q.enqueue((10.0,))
       enqueue_op.run()
 
   def testEnqueueWithShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(
           10, dtypes_lib.float32, shapes=((3, 2),))
       enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
@@ -111,14 +111,14 @@
       self.assertEqual(1, q.size().eval())
 
   def testEnqueueManyWithShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(
           10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
       q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
       self.assertEqual(4, q.size().eval())
 
   def testParallelEnqueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -144,7 +144,7 @@
       self.assertItemsEqual(elems, results)
 
   def testParallelDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -168,7 +168,7 @@
       self.assertItemsEqual(elems, results)
 
   def testDequeue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -182,7 +182,7 @@
         self.assertEqual([elems[i]], vals)
 
   def testEnqueueAndBlockingDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0]
       enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -212,7 +212,7 @@
         self.assertEqual([elem], result)
 
   def testMultiEnqueueAndDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10,
                                          (dtypes_lib.int32, dtypes_lib.float32),
                                          ((), ()))
@@ -230,12 +230,12 @@
         self.assertEqual([y], y_val)
 
   def testQueueSizeEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       self.assertEqual([0], q.size().eval())
 
   def testQueueSizeAfterEnqueueAndDequeue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue()
@@ -248,7 +248,7 @@
       self.assertEqual(0, size.eval())
 
   def testEnqueueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -261,7 +261,7 @@
         self.assertEqual([elems[i % 4]], vals)
 
   def testEmptyEnqueueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, (
           (None, None),))
       empty_t = constant_op.constant(
@@ -274,7 +274,7 @@
       self.assertEqual([0], size_t.eval())
 
   def testEmptyDequeueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, shapes=((),))
       enqueue_op = q.enqueue((10.0,))
       dequeued_t = q.dequeue_many(0)
@@ -284,7 +284,7 @@
       self.assertEqual([], dequeued_t.eval().tolist())
 
   def testEmptyDequeueManyWithDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(
           10, dtypes_lib.float32, shapes=((None,),))
       enqueue_op = q.enqueue(([10.0],))
@@ -295,7 +295,7 @@
       self.assertEqual([], dequeued_t.eval().tolist())
 
   def testEmptyDequeueUpToWithDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(
           10, dtypes_lib.float32, shapes=((None,),))
       enqueue_op = q.enqueue(([10.0],))
@@ -306,7 +306,7 @@
       self.assertEqual([], dequeued_t.eval().tolist())
 
   def testConstructPaddingFIFOQueueWithNoShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError,
           r"When providing partial shapes, a list of shapes must be provided."):
@@ -314,7 +314,7 @@
                                        None).queue_ref.eval()
 
   def testMultiEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10,
                                          (dtypes_lib.float32, dtypes_lib.int32),
                                          ((), (2,)))
@@ -332,7 +332,7 @@
         self.assertAllEqual(int_elems[i % 4], int_val)
 
   def testMultiEnqueueManyWithPartiallyKnownShapes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
       float_elems = [10.0, 20.0, 30.0, 40.0]
@@ -349,7 +349,7 @@
         self.assertAllEqual(int_elems[i % 4], int_val)
 
   def testDequeueMany(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -361,7 +361,7 @@
       self.assertAllEqual(elems[4:8], dequeued_t.eval())
 
   def testDequeueUpToNoBlocking(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@
       self.assertAllEqual(elems[4:8], dequeued_t.eval())
 
   def testMultiDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
       float_elems = [
@@ -404,7 +404,7 @@
       self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
 
   def testMultiDequeueManyWithPartiallyKnownShapes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(
           10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
       float_elems = [
@@ -443,7 +443,7 @@
               dequeued_single_t[1].get_shape()))
 
   def testMultiDequeueManyWithPartiallyKnownShapesAndVariableSizeInput(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(
           10, (dtypes_lib.string, dtypes_lib.int32),
           shapes=((None,), (1, None)))
@@ -484,7 +484,7 @@
               dequeued_single_t[1].get_shape()))
 
   def testMultiDequeueUpToPartiallyKnownShapesAndVariableInputNoBlocking(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(
           10, (dtypes_lib.string, dtypes_lib.int32),
           shapes=((None,), (1, None)))
@@ -525,7 +525,7 @@
               dequeued_single_t[1].get_shape()))
 
   def testHighDimension(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, ((4, 4, 4, 4),))
       elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
       enqueue_op = q.enqueue_many((elems,))
@@ -535,7 +535,7 @@
       self.assertAllEqual(dequeued_t.eval(), elems)
 
   def testPartiallyKnownHighDimension(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, (
           (4, None, 4, None),))
       elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
@@ -592,7 +592,7 @@
                       array_ops.placeholder(dtypes_lib.int32)))
 
   def testEnqueueWrongPartiallyKnownShapeAtRuntime(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # First dimension of second component is unknown, second
       # dimension must be 3.
       q = data_flow_ops.PaddingFIFOQueue(10,
@@ -607,7 +607,7 @@
                  feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
 
   def testEnqueueDequeueManyWrongPartiallyKnownShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # First dimension of second component is unknown, second
       # dimension must be 3.
       q = data_flow_ops.PaddingFIFOQueue(10,
@@ -625,7 +625,7 @@
         dequeued_t.eval()
 
   def testParallelEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(100)]
       enqueue_op = q.enqueue_many((elems,))
@@ -644,7 +644,7 @@
       self.assertItemsEqual(dequeued_t.eval(), elems * 10)
 
   def testParallelDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(1000)]
       enqueue_op = q.enqueue_many((elems,))
@@ -666,7 +666,7 @@
       self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelDequeueUpTo(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
       elems = [10.0 * x for x in range(1000)]
       enqueue_op = q.enqueue_many((elems,))
@@ -690,7 +690,7 @@
       self.assertItemsEqual(elems, dequeued_elems)
 
   def testParallelEnqueueAndDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
       initial_elements = [10.0] * 49
       q.enqueue_many((initial_elements,)).run()
@@ -723,7 +723,7 @@
         self.assertTrue(elem in (10.0, 20.0))
 
   def testMixtureOfEnqueueAndEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
       enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
       enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -759,7 +759,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testMixtureOfDequeueAndDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
       enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
       dequeued_t = q.dequeue()
@@ -793,7 +793,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -820,7 +820,7 @@
       self.assertAllEqual(elems, dequeued_elems)
 
   def testBlockingDequeueUpTo(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -847,7 +847,7 @@
       self.assertAllEqual(elems, dequeued_elems)
 
   def testDequeueManyWithTensorParameter(self):
-    with self.test_session():
+    with self.cached_session():
       # Define a first queue that contains integer counts.
       dequeue_counts = [random.randint(1, 10) for _ in range(100)]
       count_q = data_flow_ops.PaddingFIFOQueue(100, dtypes_lib.int32, ((),))
@@ -872,7 +872,7 @@
       self.assertEqual(elems, dequeued_elems)
 
   def testDequeueFromClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -890,7 +890,7 @@
         dequeued_t.eval()
 
   def testBlockingDequeueFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -916,7 +916,7 @@
       dequeue_thread.join()
 
   def testDequeueUpToFromClosedQueueReturnsRemainder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -938,7 +938,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
       dequeued_t = q.dequeue()
@@ -958,7 +958,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueManyFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -983,7 +983,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueManyButNotAllFromClosedQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1008,7 +1008,7 @@
       dequeue_thread.join()
 
   def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1045,7 +1045,7 @@
       close_thread.join()
 
   def testClosedBlockingDequeueManyRestoresPartialBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, (dtypes_lib.float32,
                                              dtypes_lib.float32), ((), ()))
       elems_a = [1.0, 2.0, 3.0]
@@ -1078,7 +1078,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingDequeueManyFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
       dequeued_t = q.dequeue_many(4)
@@ -1098,7 +1098,7 @@
       dequeue_thread.join()
 
   def testBlockingDequeueUpToFromClosedEmptyQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       close_op = q.close()
       dequeued_t = q.dequeue_up_to(4)
@@ -1118,7 +1118,7 @@
       dequeue_thread.join()
 
   def testEnqueueToClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       enqueue_op = q.enqueue((10.0,))
       close_op = q.close()
@@ -1131,7 +1131,7 @@
         enqueue_op.run()
 
   def testEnqueueManyToClosedQueue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1145,7 +1145,7 @@
         enqueue_op.run()
 
   def testBlockingEnqueueToFullQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1168,7 +1168,7 @@
       thread.join()
 
   def testBlockingEnqueueManyToFullQueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1195,7 +1195,7 @@
       thread.join()
 
   def testBlockingEnqueueBeforeClose(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0, 40.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1232,7 +1232,7 @@
       self.assertEqual(0, q.size().eval())
 
   def testBlockingEnqueueManyBeforeClose(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
       elems = [10.0, 20.0, 30.0]
       enqueue_op = q.enqueue_many((elems,))
@@ -1265,7 +1265,7 @@
         self.assertEqual(elem, dequeued_t.eval())
 
   def testDoesNotLoseValue(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PaddingFIFOQueue(1, dtypes_lib.float32, ((),))
       enqueue_op = q.enqueue((10.0,))
       size_t = q.size()
@@ -1275,7 +1275,7 @@
         self.assertEqual(size_t.eval(), [1])
 
   def testSharedQueueSameSession(self):
-    with self.test_session():
+    with self.cached_session():
       q1 = data_flow_ops.PaddingFIFOQueue(
           1, dtypes_lib.float32, ((),), shared_name="shared_queue")
       q1.enqueue((10.0,)).run()
@@ -1305,7 +1305,7 @@
       self.assertEqual(q2_size_t.eval(), [0])
 
   def testIncompatibleSharedQueueErrors(self):
-    with self.test_session():
+    with self.cached_session():
       q_a_1 = data_flow_ops.PaddingFIFOQueue(
           10, dtypes_lib.float32, ((),), shared_name="q_a")
       q_a_2 = data_flow_ops.PaddingFIFOQueue(
@@ -1356,7 +1356,7 @@
         q_f_2.queue_ref.op.run()
 
   def testSelectQueue(self):
-    with self.test_session():
+    with self.cached_session():
       num_queues = 10
       qlist = list()
       for _ in xrange(num_queues):
@@ -1370,7 +1370,7 @@
         self.assertEqual(q.dequeue().eval(), 10.0)
 
   def testSelectQueueOutOfRange(self):
-    with self.test_session():
+    with self.cached_session():
       q1 = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
       q2 = data_flow_ops.PaddingFIFOQueue(15, dtypes_lib.float32, ((),))
       enq_q = data_flow_ops.PaddingFIFOQueue.from_list(3, [q1, q2])
@@ -1394,7 +1394,7 @@
       sess.run(enqueue_many_op)
 
   def testResetOfBlockingOperation(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
       dequeue_op = q_empty.dequeue()
       dequeue_many_op = q_empty.dequeue_many(1)
@@ -1422,7 +1422,7 @@
         t.join()
 
   def testBigEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.int32, ((),))
       elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
       enq = q.enqueue_many((elem,))
@@ -1467,7 +1467,7 @@
       self.assertAllEqual(elem, results)
 
   def testBigDequeueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(2, dtypes_lib.int32, ((),))
       elem = np.arange(4, dtype=np.int32)
       enq_list = [q.enqueue((e,)) for e in elem]
@@ -1493,7 +1493,7 @@
       self.assertAllEqual(elem, results)
 
   def testDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dtypes = [
           dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
           dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
diff --git a/tensorflow/python/kernel_tests/parse_single_example_op_test.py b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
index bf4c89b..a84895a 100644
--- a/tensorflow/python/kernel_tests/parse_single_example_op_test.py
+++ b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
@@ -89,7 +89,7 @@
 class ParseExampleTest(test.TestCase):
 
   def _test(self, kwargs, expected_values=None, expected_err=None):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
@@ -844,7 +844,7 @@
 class ParseSingleExampleTest(test.TestCase):
 
   def _test(self, kwargs, expected_values=None, expected_err=None):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 7dff450..71d8b60 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -89,7 +89,7 @@
 class ParseExampleTest(test.TestCase):
 
   def _test(self, kwargs, expected_values=None, expected_err=None):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
@@ -937,7 +937,7 @@
 class ParseSingleExampleTest(test.TestCase):
 
   def _test(self, kwargs, expected_values=None, expected_err=None):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
@@ -1054,7 +1054,7 @@
     expected_feat_list_values = expected_feat_list_values or {}
     expected_length_values = expected_length_values or {}
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if expected_err:
         with self.assertRaisesWithPredicateMatch(expected_err[0],
                                                  expected_err[1]):
@@ -1606,7 +1606,7 @@
 class DecodeJSONExampleTest(test.TestCase):
 
   def _testRoundTrip(self, examples):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       examples = np.array(examples, dtype=np.object)
 
       json_tensor = constant_op.constant(
@@ -1696,7 +1696,7 @@
     ])
 
   def testInvalidSyntax(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       json_tensor = constant_op.constant(["{]"])
       binary_tensor = parsing_ops.decode_json_example(json_tensor)
       with self.assertRaisesOpError("Error while parsing JSON"):
@@ -1706,7 +1706,7 @@
 class ParseTensorOpTest(test.TestCase):
 
   def testToFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       expected = np.random.rand(3, 4, 5).astype(np.float32)
       tensor_proto = tensor_util.make_tensor_proto(expected)
 
@@ -1719,7 +1719,7 @@
       self.assertAllEqual(expected, result)
 
   def testToUint8(self):
-    with self.test_session():
+    with self.cached_session():
       expected = np.random.rand(3, 4, 5).astype(np.uint8)
       tensor_proto = tensor_util.make_tensor_proto(expected)
 
@@ -1732,7 +1732,7 @@
       self.assertAllEqual(expected, result)
 
   def testTypeMismatch(self):
-    with self.test_session():
+    with self.cached_session():
       expected = np.random.rand(3, 4, 5).astype(np.uint8)
       tensor_proto = tensor_util.make_tensor_proto(expected)
 
@@ -1745,7 +1745,7 @@
         tensor.eval(feed_dict={serialized: tensor_proto.SerializeToString()})
 
   def testInvalidInput(self):
-    with self.test_session():
+    with self.cached_session():
       serialized = array_ops.placeholder(dtypes.string)
       tensor = parsing_ops.parse_tensor(serialized, dtypes.uint16)
 
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 15d5702..b34d30f 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -39,7 +39,7 @@
 class PartitionerCreatorsTest(test.TestCase):
 
   def testFixedSizePartitioner(self):
-    with self.test_session():
+    with self.cached_session():
       partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
       with variable_scope.variable_scope("root", partitioner=partitioner):
         v0 = variable_scope.get_variable(
@@ -50,7 +50,7 @@
         self.assertAllEqual(v0_part, (5, 1))
 
   def testFixedSizePartitionerInt64(self):
-    with self.test_session():
+    with self.cached_session():
       partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0)
       with variable_scope.variable_scope("root", partitioner=partitioner):
         v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20])
@@ -58,7 +58,7 @@
         self.assertEqual(len(v0_list), 4)
 
   def testResourceFixedSizePartitioner(self):
-    with self.test_session():
+    with self.cached_session():
       partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
       with variable_scope.variable_scope(
           "root", partitioner=partitioner, use_resource=True):
@@ -88,7 +88,7 @@
       self.assertAllEqual(v0_part, expected_partitions)
 
   def testVariableAxisSizePartitioner(self):
-    with self.test_session():
+    with self.cached_session():
       # Create a partitioned variable of shape (4, 8, 16, 32) type float32
       # Bytes per slice along the given axes:
 
@@ -210,7 +210,7 @@
       self.assertAllEqual(v0_part, expected_partitions)
 
   def testMinMaxVariablePartitioner(self):
-    with self.test_session():
+    with self.cached_session():
       # Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
       self._testMinMaxVariablePartitioner(
           max_partitions=100,
@@ -323,7 +323,7 @@
       self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec)
 
   def testVecConstantInit(self):
-    with self.test_session():
+    with self.cached_session():
       rnd_par = constant_op.constant([1, 2, 3, 4])
       vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
       variables.global_variables_initializer().run()
@@ -334,7 +334,7 @@
       self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
 
   def testConstantInit(self):
-    with self.test_session():
+    with self.cached_session():
       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
       vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
                                                               rnd_par)
@@ -346,7 +346,7 @@
       self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
 
   def _testNameHelper(self, use_resource=False):
-    with self.test_session():
+    with self.cached_session():
       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
       with variable_scope.variable_scope("hi", use_resource=use_resource):
         vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -363,7 +363,7 @@
       self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
       self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
     # Test same variable.
-    with self.test_session():
+    with self.cached_session():
       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
       with variable_scope.variable_scope(
           "hola", use_resource=use_resource) as vs:
@@ -383,7 +383,7 @@
       self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
       self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
     # Test name_scope
-    with self.test_session():
+    with self.cached_session():
       rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
       with ops.name_scope("ola"):
         vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -408,7 +408,7 @@
     self._testNameHelper(use_resource=True)
 
   def testRandomInitValue(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(random_ops.random_uniform([200, 40]))
       vs = partitioned_variables.create_partitioned_variables(
           rnd.get_shape(), [1, 10], rnd.initialized_value())
@@ -425,7 +425,7 @@
       ])
 
   def testRandomInitUnevenPartitions(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(
           random_ops.random_uniform([20, 43], dtype=dtypes.float64))
       var_lists = [
@@ -463,7 +463,7 @@
           self._TestSaveSpec(vs, save_specs[i])
 
   def testDegenerate(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
       vs = partitioned_variables.create_partitioned_variables(
           rnd.get_shape(), [1, 1], rnd.initialized_value())
@@ -474,7 +474,7 @@
       self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
 
   def testSliceSizeOne(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
       vs = partitioned_variables.create_partitioned_variables(
           rnd.get_shape(), [10, 1], rnd.initialized_value())
@@ -492,7 +492,7 @@
     self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
     self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
                         _IotaInitializer([4, 2]))
-    with self.test_session():
+    with self.cached_session():
       vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
                                                               _IotaInitializer)
       variables.global_variables_initializer().run()
@@ -506,7 +506,7 @@
   def testRandomInitializer(self):
     # Sanity check that the slices uses a different seed when using a random
     # initializer function.
-    with self.test_session():
+    with self.cached_session():
       var0, var1 = partitioned_variables.create_partitioned_variables(
           [20, 12], [1, 2], init_ops.random_uniform_initializer())
       variables.global_variables_initializer().run()
@@ -514,7 +514,7 @@
       self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
     # Negative test that proves that slices have the same values if
     # the random initializer uses a seed.
-    with self.test_session():
+    with self.cached_session():
       var0, var1 = partitioned_variables.create_partitioned_variables(
           [20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
       variables.global_variables_initializer().run()
@@ -522,7 +522,7 @@
       self.assertAllClose(val0, val1)
 
   def testSomeErrors(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(random_ops.random_uniform([10, 43]))
       with self.assertRaises(ValueError):
         partitioned_variables.create_partitioned_variables(
@@ -547,7 +547,7 @@
             [10, 43], [1, 50], rnd.initialized_value())
 
   def testControlDepsNone(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       c = constant_op.constant(1.0)
       with ops.control_dependencies([c]):
         # d get the control dependency.
@@ -573,7 +573,7 @@
         self.assertEqual([], op.control_inputs)
 
   def testConcat(self):
-    with self.test_session() as session:
+    with self.cached_session() as session:
       var_x = variable_scope.get_variable(
           "x",
           initializer=constant_op.constant([1., 2.]),
diff --git a/tensorflow/python/kernel_tests/priority_queue_test.py b/tensorflow/python/kernel_tests/priority_queue_test.py
index 3fb9c9c..73a9c81 100644
--- a/tensorflow/python/kernel_tests/priority_queue_test.py
+++ b/tensorflow/python/kernel_tests/priority_queue_test.py
@@ -36,7 +36,7 @@
 class PriorityQueueTest(test.TestCase):
 
   def testRoundTripInsertReadOnceSorts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
           (), ()))
       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -67,7 +67,7 @@
       self.assertEqual(missed, set())
 
   def testRoundTripInsertMultiThreadedReadOnceSorts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
           (), ()))
       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -113,7 +113,7 @@
       self.assertEqual(missed, set())
 
   def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
 
       num_threads = 40
@@ -163,7 +163,7 @@
       self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
 
   def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
 
       num_threads = 40
@@ -219,7 +219,7 @@
       self.assertAllEqual(set(dequeued), set(all_enqueued_values))
 
   def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
           (), ()))
       elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -268,7 +268,7 @@
       self.assertEqual(missed, set())
 
   def testRoundTripInsertOnceReadOnceSorts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
           (), ()))
       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
@@ -289,7 +289,7 @@
         self.assertTrue((dv0, dv1) in allowed[e])
 
   def testRoundTripInsertOnceReadManySorts(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
       q.enqueue_many((elem, elem)).run()
@@ -297,7 +297,7 @@
       self.assertAllEqual(deq_values, sorted(elem))
 
   def testRoundTripInsertOnceReadOnceLotsSorts(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
       elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
       q.enqueue_many((elem, elem)).run()
@@ -306,13 +306,13 @@
       self.assertAllEqual(deq_values, sorted(elem))
 
   def testInsertingNonInt64Fails(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (()))
       with self.assertRaises(TypeError):
         q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run()
 
   def testInsertingNonScalarFails(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_priority = array_ops.placeholder(dtypes.int64)
       input_other = array_ops.placeholder(dtypes.string)
       q = data_flow_ops.PriorityQueue(2000, (dtypes.string,), (()))
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 50154a4..5f5e24b 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -61,7 +61,7 @@
     for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                   dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                   dtypes.int32, dtypes.int64]:
-      with self.test_session():
+      with self.cached_session():
         x = constant_op.constant(1, dtype=dtype)
         y = constant_op.constant(2, dtype=dtype)
         z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
@@ -71,7 +71,7 @@
     def sub_func(x, y):
       return x - y
     for dtype in [dtypes.complex64, dtypes.complex128]:
-      with self.test_session():
+      with self.cached_session():
         x = constant_op.constant(1 + 1j, dtype=dtype)
         y = constant_op.constant(2 - 2j, dtype=dtype)
         z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
@@ -81,21 +81,21 @@
     def and_func(x, y):
       return x and y
     dtype = dtypes.bool
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(True, dtype=dtype)
       y = constant_op.constant(False, dtype=dtype)
       z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
       self.assertEqual(z, False)
 
   def testSingleType(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(1.0, dtypes.float32)
       y = constant_op.constant(2.0, dtypes.float32)
       z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
       self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))
 
   def testScalar(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(1.0, dtypes.float32)
       y = constant_op.constant(2.0, dtypes.float32)
       z = self.evaluate(
@@ -103,7 +103,7 @@
       self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))
 
   def testArray(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([1.0, 2.0], dtypes.float64)
       y = constant_op.constant([2.0, 3.0], dtypes.float64)
       z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
@@ -111,14 +111,14 @@
                           np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
 
   def testComplexType(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(1 + 2j, dtypes.complex64)
       y = constant_op.constant(3 + 4j, dtypes.complex64)
       z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
       self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))
 
   def testRFFT(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)
 
       def rfft(x):
@@ -128,7 +128,7 @@
       self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))
 
   def testPythonLiteral(self):
-    with self.test_session():
+    with self.cached_session():
 
       def literal(x):
         return 1.0 if float(x) == 0.0 else 0.0
@@ -138,7 +138,7 @@
       self.assertAllClose(y, 1.0)
 
   def testList(self):
-    with self.test_session():
+    with self.cached_session():
 
       def list_func(x):
         return [x, x + 1]
@@ -150,7 +150,7 @@
 
   def testTuple(self):
     # returns a tuple
-    with self.test_session():
+    with self.cached_session():
 
       def tuple_func(x):
         return x, x + 1
@@ -161,7 +161,7 @@
       self.assertAllClose(y, [0.0, 1.0])
 
     # returns a tuple, Tout and inp a tuple
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(0.0, dtypes.float64)
       y = self.evaluate(
           script_ops.py_func(tuple_func, (x,),
@@ -176,7 +176,7 @@
     def read_and_return_strings(x, y):
       return x + y
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([b"hello", b"hi"], dtypes.string)
       y = self.evaluate(
           script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -193,7 +193,7 @@
     def read_and_return_strings(x, y):
       return x + y
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(["hello", "hi"], dtypes.string)
       y = self.evaluate(
           script_ops.py_func(read_fixed_length_numpy_strings, [],
@@ -210,7 +210,7 @@
     def read_and_return_strings(x, y):
       return x + y
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(["hello", "hi"], dtypes.string)
       y, = script_ops.py_func(read_object_array, [],
                               [dtypes.string])
@@ -219,19 +219,19 @@
 
   def testStringPadding(self):
     correct = [b"this", b"is", b"a", b"test"]
-    with self.test_session():
+    with self.cached_session():
       s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
       self.assertAllEqual(s.eval(), correct)
 
   def testStringPaddingAreConvertedToBytes(self):
     inp = ["this", "is", "a", "test"]
     correct = [b"this", b"is", b"a", b"test"]
-    with self.test_session():
+    with self.cached_session():
       s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
       self.assertAllEqual(s.eval(), correct)
 
   def testLarge(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.zeros([1000000], dtype=np.float32)
       y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
       z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
@@ -239,12 +239,12 @@
         sess.run([y[0].op, z[0].op])
 
   def testNoInput(self):
-    with self.test_session():
+    with self.cached_session():
       x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
       self.assertAllClose(x, 42.0)
 
   def testAlias(self):
-    with self.test_session():
+    with self.cached_session():
       np_array = np.array([1.0, 2.0], dtype=np.float32)
       tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
       value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
@@ -252,7 +252,7 @@
       self.assertAllEqual(np_array, [1.0, 2.0])
 
   def testReturnUnicodeString(self):
-    with self.test_session():
+    with self.cached_session():
       correct = u"你好 世界"
 
       def unicode_string():
@@ -262,7 +262,7 @@
       self.assertEqual(z.eval(), correct.encode("utf8"))
 
   def testBadNumpyReturnType(self):
-    with self.test_session():
+    with self.cached_session():
 
       def bad():
         # Structured numpy arrays aren't supported.
@@ -275,7 +275,7 @@
         y.eval()
 
   def testBadReturnType(self):
-    with self.test_session():
+    with self.cached_session():
 
       def bad():
         # Non-string python objects aren't supported.
@@ -288,7 +288,7 @@
         z.eval()
 
   def testReturnInput(self):
-    with self.test_session():
+    with self.cached_session():
 
       def ident(x):
         return x[0]
@@ -303,7 +303,7 @@
       self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))
 
   def testStateful(self):
-    # Not using self.test_session(), which disables optimization.
+    # Not using self.cached_session(), which disables optimization.
     with session_lib.Session() as sess:
       producer = iter(range(3))
       x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
@@ -312,7 +312,7 @@
       self.assertEqual(sess.run(x), 2)
 
   def testStateless(self):
-    # Not using self.test_session(), which disables optimization.
+    # Not using self.cached_session(), which disables optimization.
     with session_lib.Session() as sess:
       producer = iter(range(3))
       x, = script_ops.py_func(
@@ -331,7 +331,7 @@
     self.assertEqual(None, ops.get_gradient_function(y.op))
 
   def testCOrder(self):
-    with self.test_session():
+    with self.cached_session():
       val = [[1, 2], [3, 4]]
       x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                               [dtypes.int64])
@@ -339,7 +339,7 @@
 
   def testParallel(self):
     # Tests that tf.py_func's can run in parallel if they release the GIL.
-    with self.test_session() as session:
+    with self.cached_session() as session:
       q = queue.Queue(1)
 
       def blocking_put():
@@ -375,7 +375,7 @@
       def value(self):
         return self._value
 
-    with self.test_session():
+    with self.cached_session():
       s = State()
       op = s.increment(constant_op.constant(2, dtypes.int64))
       ret = self.evaluate(op)
@@ -389,7 +389,7 @@
 
     f = script_ops.py_func(
         do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(f), [])
 
   def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
@@ -417,21 +417,22 @@
     else:
       f = script_ops.py_func(raise_exception, [], [])
 
-    with self.test_session():
-      with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
-        self.evaluate(f)
+    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
+      self.evaluate(f)
 
   def testExceptionHandling(self):
-    self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
-    self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
-    self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
-    self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
-    self._testExceptionHandling(NotImplementedError, errors.UnimplementedError)
+    with self.cached_session():
+      self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
+      self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
+      self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
+      self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
+      self._testExceptionHandling(NotImplementedError,
+                                  errors.UnimplementedError)
 
-    class WeirdError(Exception):
-      pass
+      class WeirdError(Exception):
+        pass
 
-    self._testExceptionHandling(WeirdError, errors.UnknownError)
+      self._testExceptionHandling(WeirdError, errors.UnknownError)
 
   # ----- Tests shared by py_func and eager_py_func -----
   def testCleanup(self):
@@ -452,7 +453,7 @@
           # (see #18292)
           _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
           _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32])
- 
+
     # Call garbage collector to enforce deletion.
     make_graphs()
     ops.reset_default_graph()
@@ -565,6 +566,18 @@
     dy_dx = gradients_impl.gradients(y, x)[0]
     self.assertEqual(self.evaluate(dy_dx), 6.0)
 
+  def testEagerGradientGraphTwoOutputs(self):
+
+    def f(x, y):
+      return x * y, x / y
+
+    x = constant_op.constant(3.0)
+    y = constant_op.constant(2.0)
+    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
+                                      Tout=[dtypes.float32, dtypes.float32])
+    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
+    self.assertEqual(self.evaluate(dy_dx), 2.5)
+
   @test_util.run_in_graph_and_eager_modes
   def testEagerGradientTapeMultipleArgs(self):
 
@@ -610,7 +623,7 @@
         func=log_huber, inp=[x, m], Tout=dtypes.float32)
     dy_dx = gradients_impl.gradients(y, x)[0]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Takes the first branch of log_huber.
       y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
       self.assertEqual(y, 1.0)
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 8e06e1a..8c84b2a 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -146,7 +146,7 @@
     self.assertAllEqual(expected, v)
 
   def testOneEpoch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.IdentityReader("test_reader")
       work_completed = reader.num_work_units_completed()
       produced = reader.num_records_produced()
@@ -180,7 +180,7 @@
       self.assertAllEqual(0, queued_length.eval())
 
   def testMultipleEpochs(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.IdentityReader("test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       enqueue = queue.enqueue_many([["DD", "EE"]])
@@ -201,7 +201,7 @@
         sess.run([key, value])
 
   def testSerializeRestore(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.IdentityReader("test_reader")
       produced = reader.num_records_produced()
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
@@ -256,7 +256,7 @@
         reader.restore_state(b"BOGUS" + state[5:]).run()
 
   def testReset(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.IdentityReader("test_reader")
       work_completed = reader.num_work_units_completed()
       produced = reader.num_records_produced()
@@ -307,7 +307,7 @@
     self.assertAllEqual(self._content[index], v)
 
   def testOneEpoch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.WholeFileReader("test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       queue.enqueue_many([self._filenames]).run()
@@ -323,7 +323,7 @@
         sess.run([key, value])
 
   def testInfiniteEpochs(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.WholeFileReader("test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       enqueue = queue.enqueue_many([self._filenames])
@@ -366,7 +366,7 @@
     return filenames
 
   def _testOneEpoch(self, files):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TextLineReader(name="test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -391,7 +391,7 @@
 
   def testSkipHeaderLines(self):
     files = self._CreateFiles()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -522,7 +522,7 @@
   # gap_bytes=hop_bytes-record_bytes
   def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
     hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.FixedLengthRecordReader(
           header_bytes=self._header_bytes,
           record_bytes=self._record_bytes,
@@ -549,7 +549,7 @@
                                 files,
                                 num_overlapped_records,
                                 encoding=None):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.FixedLengthRecordReader(
           header_bytes=self._header_bytes,
           record_bytes=self._record_bytes,
@@ -621,7 +621,7 @@
 
   def testOneEpoch(self):
     files = self._CreateFiles()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TFRecordReader(name="test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -640,7 +640,7 @@
 
   def testReadUpTo(self):
     files = self._CreateFiles()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TFRecordReader(name="test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       batch_size = 3
@@ -670,7 +670,7 @@
     options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
     files = self._CreateFiles(options)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TFRecordReader(name="test_reader", options=options)
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -687,7 +687,7 @@
     options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
     files = self._CreateFiles(options)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.TFRecordReader(name="test_reader", options=options)
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -752,7 +752,7 @@
     shutil.copy(path, self.db_path)
 
   def testReadFromFile(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.LMDBReader(name="test_read_from_file")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -770,7 +770,7 @@
         k, v = sess.run([key, value])
 
   def testReadFromSameFile(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
       reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
       filename_queue = input_lib.string_input_producer(
@@ -789,7 +789,7 @@
       coord.join(threads)
 
   def testReadFromFolder(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.LMDBReader(name="test_read_from_folder")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -807,7 +807,7 @@
         k, v = sess.run([key, value])
 
   def testReadFromFileRepeatedly(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
       filename_queue = input_lib.string_input_producer(
           [self.db_path], num_epochs=None)
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 068860d..ebb9872 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -44,7 +44,7 @@
     w.close()
 
   def testRecordInputSimple(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.generateTestData("basic", 1, 1)
 
       yield_op = data_flow_ops.RecordInput(
@@ -57,7 +57,7 @@
       self.assertEqual(sess.run(yield_op), b"0000000000")
 
   def testRecordInputSimpleGzip(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.generateTestData(
           "basic",
           1,
@@ -76,7 +76,7 @@
       self.assertEqual(sess.run(yield_op), b"0000000000")
 
   def testRecordInputSimpleZlib(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.generateTestData(
           "basic",
           1,
@@ -98,7 +98,7 @@
     files = 100
     records_per_file = 100
     batches = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.generateTestData("basic", files, records_per_file)
 
       records = data_flow_ops.RecordInput(
@@ -126,7 +126,7 @@
   def testDoesNotDeadlock(self):
     # Iterate multiple times to cause deadlock if there is a chance it can occur
     for _ in range(30):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         self.generateTestData("basic", 1, 1)
 
         records = data_flow_ops.RecordInput(
@@ -141,7 +141,7 @@
           sess.run(yield_op)
 
   def testEmptyGlob(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       record_input = data_flow_ops.RecordInput(file_pattern="foo")
       yield_op = record_input.get_yield_op()
       sess.run(variables.global_variables_initializer())
@@ -152,7 +152,7 @@
     files = 10
     records_per_file = 10
     batches = 2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.generateTestData("basic", files, records_per_file)
 
       records = data_flow_ops.RecordInput(
diff --git a/tensorflow/python/kernel_tests/reduce_join_op_test.py b/tensorflow/python/kernel_tests/reduce_join_op_test.py
index 663561c..3bb4986 100644
--- a/tensorflow/python/kernel_tests/reduce_join_op_test.py
+++ b/tensorflow/python/kernel_tests/reduce_join_op_test.py
@@ -113,7 +113,7 @@
       keep_dims: Whether or not to retain reduced dimensions.
       separator: The separator to use for joining.
     """
-    with self.test_session():
+    with self.cached_session():
       output = string_ops.reduce_join(
           inputs=input_array,
           axis=axis,
@@ -136,7 +136,7 @@
       axis: The indices to reduce.
       separator: The separator to use when joining.
     """
-    with self.test_session():
+    with self.cached_session():
       output = string_ops.reduce_join(
           inputs=input_array, axis=axis, keep_dims=False, separator=separator)
       output_keep_dims = string_ops.reduce_join(
@@ -234,7 +234,7 @@
     input_array = [["a"], ["b"]]
     truth = ["ab"]
     truth_shape = None
-    with self.test_session():
+    with self.cached_session():
       placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
       reduced = string_ops.reduce_join(placeholder, axis=0)
       output_array = reduced.eval(feed_dict={placeholder.name: input_array})
@@ -247,7 +247,7 @@
     truth_dim_zero = ["thisplease", "isdo", "anot", "testpanic"]
     truth_dim_one = ["thisisatest", "pleasedonotpanic"]
     truth_shape = None
-    with self.test_session():
+    with self.cached_session():
       placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
       reduced = string_ops.reduce_join(input_array, axis=placeholder)
       output_array_dim_zero = reduced.eval(feed_dict={placeholder.name: [0]})
@@ -298,7 +298,7 @@
         self._testMultipleReduceJoin(input_array, axis=permutation)
 
   def testInvalidReductionIndices(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"):
         string_ops.reduce_join(inputs="", axis=0)
       with self.assertRaisesRegexp(ValueError,
@@ -313,7 +313,7 @@
         string_ops.reduce_join(inputs=[[""]], axis=[0, 2])
 
   def testZeroDims(self):
-    with self.test_session():
+    with self.cached_session():
       inputs = np.zeros([0, 1], dtype=str)
 
       # Reduction that drops the dim of size 0.
@@ -326,7 +326,7 @@
       self.assertAllEqual([0], output_shape)
 
   def testInvalidArgsUnknownShape(self):
-    with self.test_session():
+    with self.cached_session():
       placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
       index_too_high = string_ops.reduce_join(placeholder, axis=1)
       duplicate_index = string_ops.reduce_join(placeholder, axis=[-1, 1])
@@ -336,7 +336,7 @@
         duplicate_index.eval(feed_dict={placeholder.name: [[""]]})
 
   def testInvalidArgsUnknownIndices(self):
-    with self.test_session():
+    with self.cached_session():
       placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
       reduced = string_ops.reduce_join(["test", "test2"], axis=placeholder)
 
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index ea78b58..496a452 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -61,7 +61,7 @@
     self.assertAllEqual(output.eval(), result)
 
   def testSimple(self):
-    with self.test_session():
+    with self.cached_session():
       self._check([3], [], [3])
       self._check([3], [0], [1])
       self._check([5, 3], [], [5, 3])
@@ -71,7 +71,7 @@
 
   def testZeros(self):
     """Check that reduced_shape does the right thing with zero dimensions."""
-    with self.test_session():
+    with self.cached_session():
       self._check([0], [], [0])
       self._check([0], [0], [1])
       self._check([0, 3], [], [0, 3])
@@ -84,7 +84,7 @@
       self._check([3, 0], [0, 1], [1, 1])
 
   def testNegAxes(self):
-    with self.test_session():
+    with self.cached_session():
       self._check([10, 10, 10], [-1], [10, 10, 1])
       self._check([10, 10, 10], [-1, 2], [10, 10, 1])
       self._check([10, 10, 10], [-1, -1], [10, 10, 1])
@@ -95,7 +95,7 @@
 class ReductionUnknownShape(test.TestCase):
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       for dtype, reductions in [(dtypes.float32,
                                  (math_ops.reduce_sum, math_ops.reduce_mean,
                                   math_ops.reduce_prod, math_ops.reduce_max,
@@ -617,7 +617,7 @@
   def testGradient(self):
     s = [2, 3, 4, 2]
     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_min(t, [1, 2])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -627,7 +627,7 @@
   def testGradient2(self):
     s = [2, 3, 4, 2]
     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_min(t, [1])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -637,7 +637,7 @@
   def testGradient3(self):
     s = [2, 3, 4, 2]
     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_min(t, [2])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -647,7 +647,7 @@
   def testGradient4(self):
     s = [2, 3, 4, 2]
     x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_min(t)
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -655,7 +655,7 @@
     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
 
   def testEmptyGradients(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.zeros([0, 3])
       y = math_ops.reduce_min(x, [1])
       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -744,7 +744,7 @@
   def testGradient(self):
     s = [2, 3, 4, 2]
     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [1, 2])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -754,7 +754,7 @@
   def testGradient2(self):
     s = [2, 3, 4, 2]
     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [1])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -764,7 +764,7 @@
   def testGradient3(self):
     s = [2, 3, 4, 2]
     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [2])
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -774,7 +774,7 @@
   def testGradient4(self):
     s = [2, 3, 4, 2]
     x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
-    with self.test_session():
+    with self.cached_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t)
       jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -782,7 +782,7 @@
     self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
 
   def testEmptyGradients(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.zeros([0, 3])
       y = math_ops.reduce_max(x, [1])
       error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -960,7 +960,7 @@
 
   def testStringReduce(self):
     # Test case for GitHub issue 18712
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = math_ops.count_nonzero(constant_op.constant(["test"]))
       self.assertAllClose(sess.run(v), 1)
 
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 5daae1b..e81f562 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -18,37 +18,77 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
 
 
-class RegexFullMatchOpTest(test.TestCase):
+@parameterized.parameters(
+    (gen_string_ops.regex_full_match),
+    (gen_string_ops.static_regex_full_match))
+class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
 
-  def testRegexFullMatch(self):
+  def testRegexFullMatch(self, op):
     values = ["abaaba", "abcdabcde"]
-    with self.test_session():
-      input_vector = constant_op.constant(values, dtypes.string)
-      matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+    with self.cached_session():
+      input_tensor = constant_op.constant(values, dtypes.string)
+      matched = op(input_tensor, "a.*a").eval()
       self.assertAllEqual([True, False], matched)
 
-  def testEmptyMatch(self):
-    values = ["abc", "1"]
+  def testRegexFullMatchTwoDims(self, op):
+    values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
     with self.test_session():
-      input_vector = constant_op.constant(values, dtypes.string)
-      matched = string_ops.regex_full_match(input_vector, "").eval()
+      input_tensor = constant_op.constant(values, dtypes.string)
+      matched = op(input_tensor, "a.*a").eval()
+      self.assertAllEqual([[True, False], [True, False]], matched)
+
+  def testEmptyMatch(self, op):
+    values = ["abc", "1"]
+    with self.cached_session():
+      input_tensor = constant_op.constant(values, dtypes.string)
+      matched = op(input_tensor, "").eval()
       self.assertAllEqual([False, False], matched)
 
-  def testInvalidPattern(self):
+  def testInvalidPattern(self, op):
     values = ["abc", "1"]
-    with self.test_session():
-      input_vector = constant_op.constant(values, dtypes.string)
+    with self.cached_session():
+      input_tensor = constant_op.constant(values, dtypes.string)
       invalid_pattern = "A["
-      matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+      matched = op(input_tensor, invalid_pattern)
       with self.assertRaisesOpError("Invalid pattern"):
         matched.eval()
 
 
+class RegexFullMatchOpTest(test.TestCase):
+
+  def testRegexFullMatchDelegation(self):
+    with compat.forward_compatibility_horizon(2018, 11, 1):
+      with self.cached_session():
+        input_tensor = constant_op.constant("foo", dtypes.string)
+        pattern = "[a-z]"
+        op = string_ops.regex_full_match(input_tensor, pattern)
+        self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)
+
+        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+        op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
+        self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op_tensor.name)
+
+  def testStaticRegexFullMatchDelegation(self):
+    with compat.forward_compatibility_horizon(2018, 11, 20):
+      with self.cached_session():
+        input_tensor = constant_op.constant("foo", dtypes.string)
+        pattern = "[a-z]*"
+        op = string_ops.regex_full_match(input_tensor, pattern)
+        self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)
+
+        pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+        op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
+        self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index f0e84b8f..feac3a8 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -20,7 +20,6 @@
 
 from absl.testing import parameterized
 
-from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import gen_string_ops
@@ -100,22 +99,20 @@
       (as_tensor, as_string),
       (as_tensor, as_tensor))
   def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
-    with compat.forward_compatibility_horizon(2018, 10, 11):
-      with self.test_session():
-        input_vector = constant_op.constant("foo", dtypes.string)
-        pattern = pattern_fn("[a-z]")
-        replace = rewrite_fn(".")
-        op = string_ops.regex_replace(input_vector, pattern, replace)
-        self.assertTrue(op.name.startswith("RegexReplace"))
+    with self.test_session():
+      input_vector = constant_op.constant("foo", dtypes.string)
+      pattern = pattern_fn("[a-z]")
+      replace = rewrite_fn(".")
+      op = string_ops.regex_replace(input_vector, pattern, replace)
+      self.assertTrue(op.name.startswith("RegexReplace"))
 
   def testStaticRegexReplaceDelegation(self):
-    with compat.forward_compatibility_horizon(2018, 10, 11):
-      with self.test_session():
-        input_vector = constant_op.constant("foo", dtypes.string)
-        pattern = "[a-z]"
-        replace = "."
-        op = string_ops.regex_replace(input_vector, pattern, replace)
-        self.assertTrue(op.name.startswith("StaticRegexReplace"))
+    with self.test_session():
+      input_vector = constant_op.constant("foo", dtypes.string)
+      pattern = "[a-z]"
+      replace = "."
+      op = string_ops.regex_replace(input_vector, pattern, replace)
+      self.assertTrue(op.name.startswith("StaticRegexReplace"))
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 657d92f..a45a325 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -104,7 +104,7 @@
   # The gradient test for ReLU is a bit tricky as the derivative is not well
   # defined at around zero and we want to avoid that in terms of input values.
   def testGradientFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -149,7 +149,7 @@
         self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4)
 
   def testGradientFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -166,7 +166,7 @@
     self.assertLess(err, 1e-10)
 
   def testGradGradFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -183,7 +183,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -201,7 +201,7 @@
     self.assertLess(err, 1e-10)
 
   def testGradientScalar(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = variables.Variable(100.)
       y = nn_ops.relu(x)
       loss = y**2
@@ -249,7 +249,7 @@
   # not well defined at around zero and six and we want to avoid that
   # in terms of input values.
   def testGradientFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
           shape=[2, 5],
@@ -265,7 +265,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
           shape=[2, 5],
@@ -313,7 +313,7 @@
           use_gpu=True)
 
   def testGradientFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
       x = constant_op.constant(x_val, name="x")
       y = nn_ops.elu(x, name="elu")
@@ -324,7 +324,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
       x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
       y = nn_ops.elu(x, name="elu")
@@ -335,7 +335,7 @@
     self.assertLess(err, 1e-6)
 
   def testGradGrad(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtype=dtypes.float32)
       elu = nn_ops.elu(x)
       g, = gradients_impl.gradients(elu, x)
@@ -346,7 +346,7 @@
         self.assertLess(err, 1e-4)
 
   def testGradGradFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -363,7 +363,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -415,7 +415,7 @@
           use_gpu=True)
 
   def testGradientFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
       x = constant_op.constant(x_val, name="x")
       y = nn_ops.selu(x, name="selu")
@@ -426,7 +426,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradientFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
       x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
       y = nn_ops.selu(x, name="selu")
@@ -437,7 +437,7 @@
     self.assertLess(err, 1e-6)
 
   def testGradGradFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -454,7 +454,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradGradFloat64(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -503,7 +503,7 @@
             use_gpu=True)
 
   def testNumbersWithAxis0(self):
-    with self.test_session():
+    with self.cached_session():
       crelu = nn_ops.crelu(
           np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
       tf_relu = crelu.eval()
@@ -512,7 +512,7 @@
       self.assertAllEqual(np_crelu, tf_relu)
 
   def testNumbersWithAxis1(self):
-    with self.test_session():
+    with self.cached_session():
       crelu = nn_ops.crelu(
           np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
       tf_relu = crelu.eval()
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index ef9b439..ca3ff1d 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -94,7 +94,7 @@
   def testFloatReshapeGradThreeDimensions(self):
     x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
     s = list(np.shape(x))
-    with self.test_session():
+    with self.cached_session():
       input_tensor = constant_op.constant(x)
       reshape_out = array_ops.reshape(input_tensor, [1, 8, 3])
       err = gradient_checker.compute_gradient_error(
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index d0ed089..f90545f 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -54,7 +54,7 @@
     self.assertEqual(0, len(gc.garbage))
 
   def testHandleDtypeShapeMatch(self):
-    with self.test_session():
+    with self.cached_session():
       handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
       with self.assertRaises(ValueError):
         resource_variable_ops.assign_variable_op(
@@ -123,7 +123,7 @@
       self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy()))
 
   def testGraphDeepCopy(self):
-    with self.test_session():
+    with self.cached_session():
       init_value = np.ones((4, 4, 4))
       variable = resource_variable_ops.ResourceVariable(init_value,
                                                         name="init")
@@ -145,13 +145,13 @@
                    # variable graph.
 
   def testFetchHandle(self):
-    with self.test_session():
+    with self.cached_session():
       handle = resource_variable_ops.var_handle_op(
           dtype=dtypes.int32, shape=[1], name="foo")
       self.assertGreater(len(handle.eval()), 0)
 
   def testCachedValueReadBeforeWrite(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0")
       sess.run(v.initializer)
       value, _ = sess.run([v, v.assign_add(1.0)])
@@ -492,7 +492,7 @@
 
   # TODO(alive): how should this work in Eager mode?
   def testInitFn(self):
-    with self.test_session():
+    with self.cached_session():
       v = resource_variable_ops.ResourceVariable(
           initial_value=lambda: 1, dtype=dtypes.float32)
       self.assertEqual(v.handle.op.colocation_groups(),
@@ -569,11 +569,11 @@
     self.assertEqual(2.0, self.evaluate(v.value()))
 
   def testVariableDefInitializedInstances(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v_def = resource_variable_ops.ResourceVariable(
           initial_value=constant_op.constant(3.0)).to_proto()
 
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       # v describes a VariableDef-based variable without an initial value.
       v = resource_variable_ops.ResourceVariable(variable_def=v_def)
       self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -584,7 +584,7 @@
       self.assertEqual(1.0, v.initialized_value().eval())
 
     v_def.ClearField("initial_value_name")
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       # Restoring a legacy VariableDef proto that does not have
       # initial_value_name set should still work.
       v = resource_variable_ops.ResourceVariable(variable_def=v_def)
@@ -615,17 +615,16 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testSparseRead(self):
-    with self.test_session():
-      init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
-      v = resource_variable_ops.ResourceVariable(
-          constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
-      self.evaluate(variables.global_variables_initializer())
+    init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
+    v = resource_variable_ops.ResourceVariable(
+        constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
+    self.evaluate(variables.global_variables_initializer())
 
-      value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
-      self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
+    value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
+    self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value)
 
   def testToFromProto(self):
-    with self.test_session():
+    with self.cached_session():
       v = resource_variable_ops.ResourceVariable(1.0)
       variables.global_variables_initializer().run()
 
@@ -686,7 +685,7 @@
         handle, ignore_lookup_error=True))
 
   def testAssignDifferentShapes(self):
-    with self.test_session() as sess, variable_scope.variable_scope(
+    with self.cached_session() as sess, variable_scope.variable_scope(
         "foo", use_resource=True):
       var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32)
       placeholder = array_ops.placeholder(dtypes.float32)
@@ -728,7 +727,7 @@
         _ = w.value().op.get_attr("_class")
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       v = resource_variable_ops.ResourceVariable(300.0, name="var4")
       variables.global_variables_initializer().run()
 
@@ -746,7 +745,7 @@
         resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval()
 
   def testSharedNameWithNamescope(self):
-    with self.test_session():
+    with self.cached_session():
       with ops.name_scope("foo"):
         v = resource_variable_ops.ResourceVariable(300.0, name="var6")
         self.assertEqual("foo/var6", v._shared_name)  # pylint: disable=protected-access
@@ -774,7 +773,7 @@
           str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape))
 
   def testSetInitialValue(self):
-    with self.test_session():
+    with self.cached_session():
       # Initialize variable with a value different from the initial value passed
       # in the constructor.
       v = resource_variable_ops.ResourceVariable(2.0)
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index 9beb615..8fc71e0 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -120,7 +120,7 @@
     batch_axis = 2
     seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
 
-    with self.test_session():
+    with self.cached_session():
       input_t = constant_op.constant(x, shape=x.shape)
       seq_lengths_t = constant_op.constant(seq_lengths, shape=seq_lengths.shape)
       reverse_sequence_out = array_ops.reverse_sequence(
@@ -171,7 +171,7 @@
           seq_axis=0,
           batch_axis=3)
 
-    with self.test_session():
+    with self.cached_session():
       inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3))
       seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,))
       output = array_ops.reverse_sequence(
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 562d11f..a28cdc3 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -197,7 +197,7 @@
     else:
       inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       outputs, state = rnn.dynamic_rnn(
           cell, inputs, dtype=dtypes.float32, sequence_length=[4])
       if not in_eager_mode:
@@ -217,7 +217,7 @@
     else:
       inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       outputs, state = rnn.dynamic_rnn(
           cell, inputs, dtype=dtypes.float32, sequence_length=[4])
       if not in_eager_mode:
@@ -246,7 +246,7 @@
     else:
       inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
 
-    with self.test_session() as sess:
+    with self.cached_session(use_gpu=True) as sess:
       outputs, state = rnn.dynamic_rnn(
           cell, inputs, dtype=dtypes.float32, sequence_length=[4])
       state = (state[0], state[1].stack())
@@ -321,7 +321,7 @@
     self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
 
   def testRNNWithKerasSimpleRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_shape = 10
       output_shape = 5
       timestep = 4
@@ -354,7 +354,7 @@
       self.assertEqual(len(state), batch)
 
   def testRNNWithKerasGRUCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_shape = 10
       output_shape = 5
       timestep = 4
@@ -387,7 +387,7 @@
       self.assertEqual(len(state), batch)
 
   def testRNNWithKerasLSTMCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_shape = 10
       output_shape = 5
       timestep = 4
@@ -424,7 +424,7 @@
       self.assertEqual(len(state[1]), batch)
 
   def testRNNWithStackKerasCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_shape = 10
       output_shape = 5
       timestep = 4
@@ -465,7 +465,7 @@
         self.assertEqual(len(s), batch)
 
   def testStaticRNNWithKerasSimpleRNNCell(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       input_shape = 10
       output_shape = 5
       timestep = 4
@@ -567,7 +567,7 @@
         rnn_cell_impl.GRUCell(
             32, kernel_initializer="ones", dtype=dtypes.float32)
     ]:
-      with self.test_session():
+      with self.cached_session():
         x = keras.Input((None, 5))
         layer = keras.layers.RNN(cell)
         y = layer(x)
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f2f3023..86e063c 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -294,7 +294,7 @@
     self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
 
     expected_result = np.zeros([2, 2], dtype=np.int32)
-    with self.test_session():
+    with self.cached_session():
       ref.initializer.run()
       self.assertAllEqual(expected_result, scatter_update.eval())
 
@@ -409,7 +409,7 @@
     expected = np.array([b"", b"one", b"", b"three", b"four",
                          b"", b"", b"seven"])
     scatter = self.scatter_nd(indices, updates, shape=(8,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = sess.run(scatter)
       self.assertAllEqual(expected, result)
 
@@ -420,7 +420,7 @@
                                    dtype=dtypes.string)
     expected = np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"])
     scatter = self.scatter_nd(indices, updates, shape=(8,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = sess.run(scatter)
       self.assertAllEqual(expected, result)
 
@@ -432,7 +432,7 @@
     expected = [np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
                 np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"])]
     scatter = self.scatter_nd(indices, updates, shape=(8,))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       result = sess.run(scatter)
       self.assertTrue(np.array_equal(result, expected[0]) or
                       np.array_equal(result, expected[1]))
@@ -451,7 +451,7 @@
     scatter = self.scatter_nd(indices, updates, shape)
     self.assertAllEqual(scatter.get_shape().as_list(), shape)
     expected_result = np.zeros([2, 2], dtype=np.int32)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_result, scatter.eval())
 
   def testUndefinedIndicesShape(self):
@@ -486,7 +486,7 @@
     updates = array_ops.placeholder(dtypes.int32, shape=None)
     shape = constant_op.constant([0, 3, 2], dtypes.int32)
 
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError(
           "Indices and updates specified for empty output"):
         self.scatter_nd(indices, updates, shape).eval(feed_dict={
@@ -500,7 +500,7 @@
     shape = constant_op.constant([0], dtypes.int32)
     scatter = self.scatter_nd(indices, updates, shape)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertEqual(scatter.eval().size, 0)
 
   def testRank3InvalidShape1(self):
@@ -531,7 +531,7 @@
         [outputs], [updates, input_], [grad_vals])
     expected_updates_grad = np.array([1, 4], dtype=np.float64)
     expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_updates_grad, updates_grad.eval())
       if self.non_aliasing_add_test:
         self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -548,7 +548,7 @@
         [outputs], [updates, input_], [grad_vals])
     expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
     expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_updates_grad, updates_grad.eval())
       if self.non_aliasing_add_test:
         self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -570,7 +570,7 @@
         [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
     expected_input_grad = np.array(
         [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_updates_grad, updates_grad.eval())
       if self.non_aliasing_add_test:
         self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -607,7 +607,7 @@
             [[[[1, 2], [3, 4]]]],
             [[[[5, 6], [7, 8]]]]
         ]]], dtype=np.float64)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected_updates_grad, updates_grad.eval())
       if self.non_aliasing_add_test:
         self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -616,33 +616,33 @@
     indices = array_ops.zeros([100000, 1], dtypes.int32)
     values = np.random.randn(100000)
     shape = [1]
-    with self.test_session():
+    with self.cached_session():
       val = self.scatter_nd(indices, values, shape).eval()
     self.assertAllClose([np.sum(values)], val)
 
   def testSmokeScatterNdBatch2DSliceDim2(self):
-    with self.test_session():
+    with self.cached_session():
       indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
       values = array_ops.zeros([3, 5, 7])
       shape = [4, 6, 7]
       self.scatter_nd(indices, values, shape).eval()
 
   def testSmokeScatterNdBatch1DSliceDim2(self):
-    with self.test_session():
+    with self.cached_session():
       indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
       values = array_ops.zeros([0, 7])
       shape = [4, 6, 7]
       self.scatter_nd(indices, values, shape).eval()
 
   def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
-    with self.test_session():
+    with self.cached_session():
       indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
       values = array_ops.zeros([1, 6, 7, 8, 9])
       shape = [3, 4, 5, 6, 7, 8, 9]
       self.scatter_nd(indices, values, shape).eval()
 
   def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
-    with self.test_session():
+    with self.cached_session():
       indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
       values = array_ops.zeros([1, 2, 6, 7, 8, 9])
       shape = [3, 4, 5, 6, 7, 8, 9]
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index a82855d..ce507e4 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -177,7 +177,7 @@
 
   def testSegmentIdsInvalid1(self):
     shape = [4, 4]
-    with self.test_session():
+    with self.cached_session():
       tf_x, _ = self._input(shape)
       indices = [-1, -1, 0, 0]
       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -188,7 +188,7 @@
 
   def testSegmentIdsInvalid2(self):
     shape = [4, 4]
-    with self.test_session():
+    with self.cached_session():
       tf_x, _ = self._input(shape)
       indices = [0, 1, 0, 1]
       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -197,7 +197,7 @@
 
   def testSegmentIdsInvalid3(self):
     shape = [4, 4]
-    with self.test_session():
+    with self.cached_session():
       tf_x, _ = self._input(shape)
       indices = [0, 1, 2, 0]
       s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -233,7 +233,7 @@
         math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min,
         math_ops.segment_max
     ]:
-      with self.test_session():
+      with self.cached_session():
         tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
         s = tf_op(data=tf_x, segment_ids=indices)
         jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -736,7 +736,7 @@
     segment_indices = [0, 1, 2, 2]
     num_indices = len(segment_indices)
     for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
-      with self.test_session():
+      with self.cached_session():
         tf_indices, _, tf_x, np_x = self._sparse_input(
             shape, num_indices, dtype=dtypes_lib.float64)
         s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
@@ -758,7 +758,7 @@
         math_ops.sparse_segment_sum_with_num_segments,
         math_ops.sparse_segment_mean_with_num_segments,
     ]:
-      with self.test_session():
+      with self.cached_session():
         tf_indices, _, tf_x, np_x = self._sparse_input(
             shape, num_indices, dtype=dtypes_lib.float64)
         s = tf_op(
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
index 678016b..03e1ae8 100644
--- a/tensorflow/python/kernel_tests/session_ops_test.py
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -31,7 +31,7 @@
 class SessionOpsTest(test.TestCase):
 
   def testHandleBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -45,7 +45,7 @@
       self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
 
   def testHandleEval(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -57,7 +57,7 @@
       self.assertEqual(50, h.eval())
 
   def testHandleAndValue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle and a value.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -70,7 +70,7 @@
       self.assertEqual(500, v)
 
   def testHandleCond(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle and a value
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -90,7 +90,7 @@
       self.assertEqual(5000, result)
 
   def testHandleForLoop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize a handle.
       a = constant_op.constant(0)
       h = session_ops.get_session_handle(a)
@@ -107,7 +107,7 @@
       self.assertEqual(100, h.eval())
 
   def testHandleWhileLoop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize a handle.
       a = constant_op.constant(0)
       h = session_ops.get_session_handle(a)
@@ -127,7 +127,7 @@
       self.assertEqual(101, h.eval())
 
   def testHandleMover(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -148,7 +148,7 @@
         self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
 
   def testHandleDelete(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -157,7 +157,7 @@
       sess.run(h).delete()
 
   def testHandleDeleteRaw(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Return a handle.
       a = constant_op.constant(10)
       b = constant_op.constant(5)
@@ -171,7 +171,7 @@
       sess.run(x, feed_dict={f: raw_h})
 
   def testMultiDevices(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with ops.device(test.gpu_device_name()):
         a = constant_op.constant(1.0)
         a_handle = sess.run(session_ops.get_session_handle(a))
@@ -189,7 +189,7 @@
       self.assertEqual(3.0, c_handle.eval())
 
   def testHandleGC(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # initial values live on CPU
       with ops.device("/cpu:0"):
         one = constant_op.constant(1, dtype=dtypes.float32)
@@ -213,7 +213,7 @@
                        add_h2: x_handle.handle})
 
   def testHandlePlacement(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant(1.0)
       a_handle_op = session_ops.get_session_handle(a)
       b = constant_op.constant(2.0)
@@ -233,7 +233,7 @@
       self.assertEqual(3.0, c_handle.eval())
 
   def testFeedOneHandleDirectly(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant(10.0)
       b = constant_op.constant(5.0)
       c = math_ops.multiply(a, b)
@@ -244,7 +244,7 @@
       self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
 
   def testDirectHandleFeedOverlappingWithFetches(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant(10.0)
       b = constant_op.constant(5.0)
       c = math_ops.multiply(a, b)
@@ -270,7 +270,7 @@
       self.assertAllClose(50.0, d_val)
 
   def testFeedTwoHandlesDirectly(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant(10.0)
       b = constant_op.constant(5.0)
       c = math_ops.multiply(a, b)
@@ -284,7 +284,7 @@
       self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
 
   def testFeedHandleToVariableDirectly(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = variables.Variable(12.0)
       inc_a = state_ops.assign_add(a, 2.0)
       b = math_ops.add(a, 5.0)
diff --git a/tensorflow/python/kernel_tests/sets_test.py b/tensorflow/python/kernel_tests/sets_test.py
index 52b7238..8335e9c 100644
--- a/tensorflow/python/kernel_tests/sets_test.py
+++ b/tensorflow/python/kernel_tests/sets_test.py
@@ -158,7 +158,7 @@
     for op in ops:
       self.assertEqual(None, op.get_shape().dims)
       self.assertEqual(dtypes.int32, op.dtype)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       results = sess.run(ops)
     self.assertAllEqual(results[0], results[1])
     return results[0]
@@ -477,7 +477,7 @@
     dynamic_values_shape_ops = []
     static_indices_shape = None
     static_values_shape = None
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       for op in ops:
         if static_indices_shape is None:
           static_indices_shape = op.indices.get_shape()
@@ -533,7 +533,7 @@
 
   def _set_intersection_count(self, a, b):
     op = sets.set_size(sets.set_intersection(a, b))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       return sess.run(op)
 
   def test_set_difference_multirow_2d(self):
@@ -971,7 +971,7 @@
 
   def _set_difference_count(self, a, b, aminusb=True):
     op = sets.set_size(sets.set_difference(a, b, aminusb))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       return sess.run(op)
 
   def test_set_union_multirow_2d(self):
@@ -1220,7 +1220,7 @@
 
   def _set_union_count(self, a, b):
     op = sets.set_size(sets.set_union(a, b))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       return sess.run(op)
 
   def _assert_set_operation(self, expected_indices, expected_values,
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 34e34d9..0304dc3 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -158,7 +158,7 @@
   # Disabled because it takes too long to run, but manually verified
   # as passing at time of writing.
   def _test64BitOutput(self):
-    with self.test_session():
+    with self.cached_session():
       inp = array_ops.zeros([2**31])
       num_elements = array_ops.size_internal(
           inp, optimize=False, out_type=dtypes.int64)
@@ -166,7 +166,7 @@
 
     # Too large for tf.int32 output.
     with self.assertRaises(errors_impl.InvalidArgumentError):
-      with self.test_session():
+      with self.cached_session():
         inp = array_ops.zeros([2**31])
         num_elements = array_ops.size_internal(
             inp, optimize=False, out_type=dtypes.int32)
@@ -228,7 +228,7 @@
     self._compareExpandDimsAll(choice([2, 3, 5]), -4)
 
   def testExpandDimsErrors(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertRaises(ValueError, array_ops.expand_dims,
                         np.zeros([2, 3, 5]), -5)
       self.assertRaises(ValueError, array_ops.expand_dims,
@@ -239,7 +239,7 @@
                         [False, True, True], 4)
 
   def testExpandDimsGradient(self):
-    with self.test_session():
+    with self.cached_session():
       inp = constant_op.constant(
           np.random.rand(4, 2).astype("f"), dtype=dtypes.float32)
       squeezed = array_ops.expand_dims(inp, 1)
@@ -249,7 +249,7 @@
     self.assertLess(err, 1e-3)
 
   def testExpandDimsScalar(self):
-    with self.test_session():
+    with self.cached_session():
       inp = constant_op.constant(7)
       self.assertAllEqual([7], array_ops.expand_dims(inp, 0).eval())
       self.assertAllEqual([7], array_ops.expand_dims(inp, -1).eval())
@@ -375,7 +375,7 @@
                           np.zeros([1, 2, 1]), [2, 3])
 
   def testSqueezeGradient(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 2).astype("f")
       a = array_ops.reshape(inp, [4, 1, 2])
       squeezed = array_ops.squeeze(a, [])
@@ -385,7 +385,7 @@
     self.assertLess(err, 1e-3)
 
   def testSqueezeGradientWithSqueezeDims(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 2).astype("f")
       a = array_ops.reshape(inp, [4, 1, 2, 1])
       squeezed = array_ops.squeeze(a, [1])
@@ -395,7 +395,7 @@
     self.assertLess(err, 1e-3)
 
   def testSqueezeWithUnknownShape(self):
-    with self.test_session():
+    with self.cached_session():
       a = array_ops.placeholder(dtypes.float32, shape=[2, None])
 
       squeezed = array_ops.squeeze(a, [1])
@@ -433,7 +433,7 @@
       self.assertTrue((result == np.tile(inp, (1, 4))).all())
 
   def testIdentityTileAndGrad(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 1).astype(np.float32)
       a = constant_op.constant(inp)
       tiled = array_ops.tile(a, [1, 1])
@@ -443,7 +443,7 @@
     self.assertTrue((result == np.tile(inp, (1, 1))).all())
 
   def testEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(2, 3).astype(np.float32)
       a = constant_op.constant(inp)
       tiled = array_ops.tile(a, [5, 0])
@@ -453,7 +453,7 @@
 
   def testUnknownInputShape(self):
     """Importing can call _TileShape without shape of <multiples> known."""
-    with self.test_session():
+    with self.cached_session():
       inp = array_ops.placeholder(dtypes.float32)  # unknown shape
       multiples = constant_op.constant([1, 2, 3, 4], dtype=np.int32)
       tiled = array_ops.tile(inp, multiples)
@@ -503,7 +503,7 @@
       self.assertAllEqual(result, np.tile(inp, (1, 4)))
 
   def testInvalidDim(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 1).astype("f")
       a = constant_op.constant(
           [float(x) for x in inp.ravel(order="C")],
@@ -546,7 +546,7 @@
       self._RunAndVerifyResult(10, use_gpu=True)
 
   def testGradientSimpleReduction(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 1).astype("f")
       a = constant_op.constant(
           [float(x) for x in inp.flatten()], shape=[4, 1], dtype=dtypes.float32)
@@ -561,7 +561,7 @@
     self.assertAllClose(np.sum(grad_inp, axis=1).reshape(4, 1), result, 1e-3)
 
   def testGradientStridedReduction(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 2).astype("f")
       a = constant_op.constant(
           [float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -634,7 +634,7 @@
     self._RunAndVerifyGradientResult([2, 1, 3, 3, 2], [1, 3, 3, 1, 2])
 
   def testGradientStridedReductionGC(self):
-    with self.test_session():
+    with self.cached_session():
       inp = np.random.rand(4, 2).astype("f")
       a = constant_op.constant(
           [float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -647,7 +647,7 @@
                                   dtype=dtypes.float32)
     outputs = array_ops.gather(array_ops.tile(inputs, [3]),
                                [1, 5, 9, 3, 7, 2, 2, 2])
-    with self.test_session():
+    with self.cached_session():
       error = gradient_checker.compute_gradient_error(
           inputs, inputs.get_shape().as_list(),
           outputs, outputs.get_shape().as_list())
@@ -659,7 +659,7 @@
     inputs = array_ops.reshape(inputs, [-1, 1, 1])
     outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]),
                                [1, 5, 9, 3, 7, 2, 2, 2])
-    with self.test_session():
+    with self.cached_session():
       error = gradient_checker.compute_gradient_error(
           inputs, inputs.get_shape().as_list(),
           outputs, outputs.get_shape().as_list())
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 4a1fc1d..c08d322 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -26,6 +26,7 @@
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 
 
@@ -106,7 +107,7 @@
 
   def testScalarInput(self):
     input_val = 0
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test with constant input; shape inference fails.
       with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
         constant_op.constant(input_val)[:].get_shape()
@@ -120,7 +121,7 @@
 
   def testInvalidIndex(self):
     input_val = [1, 2]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Test with constant input; shape inference fails.
       with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
         constant_op.constant(input_val)[1:, 1:].get_shape()
@@ -260,6 +261,21 @@
       grad_actual = gradients_impl.gradients(out, inp)[0].eval()
     self.assertAllClose([0., 1., 1.], grad_actual)
 
+  def _testGradientVariableSize2D(self):
+    # Regression test for bug in slice. A low-level bug in Eigen was causing
+    # incorrect results for negative indices in multi-dimensional tensors.
+    # See b/114318298.
+    with self.cached_session(use_gpu=True) as sess:
+      x = constant_op.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 7]])
+      loss1 = math_ops.reduce_sum(x[:-1, :-1] * 1.0)
+      loss2 = math_ops.reduce_sum(x[:-1][:, :-1])
+
+      g1 = gradients_impl.gradients(loss1, x)[0]
+      g2 = gradients_impl.gradients(loss2, x)[0]
+
+      g1_val, g2_val = sess.run([g1, g2])
+    self.assertAllEqual(g1_val, g2_val)
+
   def testGradientsAll(self):
     # Slice the middle square out of a 4x4 input
     self._testGradientSlice([4, 4], [1, 1], [2, 2])
@@ -276,6 +292,9 @@
     # Use -1 as a slice dimension.
     self._testGradientVariableSize()
 
+    # Use -1 as a slice dimension on a 2D tensor.
+    self._testGradientVariableSize2D()
+
   def testNotIterable(self):
     # NOTE(mrry): If we register __getitem__ as an overloaded
     # operator, Python will valiantly attempt to iterate over the
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index fbf1adb..e53347c 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -210,7 +210,7 @@
     self.assertEqual([3, 2, 4], op.get_shape())
 
   def testEmptyInput(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtypes.float32, shape=[0, 3])
       self.assertEqual(0, array_ops.size(x).eval())
       # reshape would raise if logits is empty
@@ -218,7 +218,7 @@
         nn_ops.softmax(x, axis=0).eval()
 
   def testDimTooLarge(self):
-    with self.test_session():
+    with self.cached_session():
       # Use placeholder to make sure we get runtime error instead of shape
       # inference error.
       dim = array_ops.placeholder_with_default(100, shape=[])
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index c0269db..afe3df6 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -72,7 +72,7 @@
           use_gpu=True)
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -88,7 +88,7 @@
     self.assertLess(err, 1e-4)
 
   def testGradGrad(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -105,7 +105,7 @@
     self.assertLess(err, 5e-5)
 
   def testGradGradGrad(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -123,7 +123,7 @@
     self.assertLess(err, 5e-5)
 
   def testNoInts(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           "No OpKernel was registered to support Op 'Softplus'"):
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index a5247ce..05a7c53 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -51,7 +51,7 @@
           use_gpu=True)
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -67,7 +67,7 @@
     self.assertLess(err, 1e-4)
 
   def testNoInts(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
           "No OpKernel was registered to support Op 'Softsign'"):
diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
index 2a9232b..e267c05 100644
--- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
@@ -551,7 +551,7 @@
   def _checkGrad(self, x, block_shape, paddings):
     block_shape = np.array(block_shape)
     paddings = np.array(paddings).reshape((len(block_shape), 2))
-    with self.test_session():
+    with self.cached_session():
       tf_x = ops.convert_to_tensor(x)
       tf_y = array_ops.space_to_batch_nd(tf_x, block_shape, paddings)
       epsilon = 1e-5
@@ -638,7 +638,7 @@
     t_paddings, t_crops = array_ops.required_space_to_batch_paddings(
         input_shape_placeholder, block_shape_placeholder,
         base_paddings_placeholder)
-    with self.test_session():
+    with self.cached_session():
       paddings_result = t_paddings.eval(assignments)
       crops_result = t_crops.eval(assignments)
     self.assertAllEqual(paddings_result, paddings_const)
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index d749843..4777203 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -61,14 +61,22 @@
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q")
     self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       name:'Q' op:'SparseConditionalAccumulator'
       attr { key: 'dtype' value { type: DT_FLOAT } }
       attr { key: 'shape' value { shape { unknown_rank: true} } }
       attr { key: 'container' value { s: '' } }
       attr { key: 'shared_name' value { s: '' } }
+      attr { key: 'reduction_type' value {s: 'MEAN'} }
       """, q.accumulator_ref.op.node_def)
 
+  def testConstructorWithInvalidArg(self):
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):
+        data_flow_ops.SparseConditionalAccumulator(
+            dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
   def testConstructorWithShape(self):
     with ops.Graph().as_default():
       q = data_flow_ops.SparseConditionalAccumulator(
@@ -76,7 +84,8 @@
           name="Q",
           shape=tensor_shape.TensorShape([1, 5, 2, 8]))
     self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
-    self.assertProtoEquals("""
+    self.assertProtoEquals(
+        """
       name:'Q' op:'SparseConditionalAccumulator'
       attr { key: 'dtype' value { type: DT_FLOAT } }
       attr { key: 'shape' value { shape { dim {size: 1 }
@@ -86,23 +95,24 @@
       } } }
       attr { key: 'container' value { s: '' } }
       attr { key: 'shared_name' value { s: '' } }
+      attr { key: 'reduction_type' value {s: 'MEAN'} }
       """, q.accumulator_ref.op.node_def)
 
   def testAccumulatorSizeEmpty(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q")
       self.assertEqual(q.num_accumulated().eval(), 0)
 
   def testAccumulatorSetGlobalStep(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
       set_global_step_op = q.set_global_step(1)
       set_global_step_op.run()
 
   def testAccumulatorApplyGradFloat32(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
       accum_op = q.apply_indexed_slices_grad(
@@ -113,7 +123,7 @@
       self.assertEqual(q.num_accumulated().eval(), 1)
 
   def testDtypes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
 
       for i in range(len(dtypes)):
@@ -135,7 +145,7 @@
         self._assertEqual_nparray(sum_elems / len(elems), result, sess)
 
   def testAccumulatorMultipleAccumulators(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q_f32_0 = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
       q_f32_1 = data_flow_ops.SparseConditionalAccumulator(
@@ -164,8 +174,8 @@
         result = sess.run(accums[i].take_indexed_slices_grad(1))
         self._assertEqual_indexedslices(expected_tensors[i], result)
 
-  def testAccumulatorTakeGrad(self):
-    with self.test_session() as sess:
+  def testAccumulatorTakeGradMean(self):
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=())
 
@@ -180,12 +190,37 @@
 
       takeg_t = q.take_indexed_slices_grad(1)
       val = sess.run(takeg_t)
-      self.assertAllEqual(val.indices, [0, 1, 2])
-      self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]])
-      self.assertAllEqual(val.dense_shape, [-1, 2])
+      self.assertAllEqual([0, 1, 2], val.indices)
+      self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
+      self.assertAllEqual([-1, 2], val.dense_shape)
+
+  def testAccumulatorTakeGradSum(self):
+    with self.cached_session() as sess:
+      q = data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
+
+      grad_indexed_slices = ops.IndexedSlices(
+          indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
+      accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
+      accum_op.run()
+      accum_op = q.apply_grad([0, 2],
+                              np.array([[0, 1], [3, 0]]).astype(np.float32),
+                              [3, 2])
+      accum_op.run()
+
+      takeg_t = q.take_indexed_slices_grad(1)
+      val = sess.run(takeg_t)
+      self.assertAllEqual([0, 1, 2], val.indices)
+      self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
+      self.assertAllEqual([-1, 2], val.dense_shape)
+
+  def testAccumulatorTakeGradInvalidReductionType(self):
+    with self.assertRaises(ValueError):
+      data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
 
   def testAccumulatorRepeatedTakeGrad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=())
 
@@ -222,8 +257,8 @@
       self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
       self.assertAllEqual(val.dense_shape, [-1, 2])
 
-  def testParallelApplyGrad(self):
-    with self.test_session() as sess:
+  def testParallelApplyGradMean(self):
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
       elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -253,9 +288,43 @@
           np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
           val, sess)
 
-  def testParallelTakeGrad(self):
+  def testParallelApplyGradSum(self):
     with self.test_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
+          dtypes_lib.float32,
+          name="Q",
+          shape=tensor_shape.TensorShape([2, 2]),
+          reduction_type="SUM")
+      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+      accum_ops = []
+      for x in elems:
+        x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
+        accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
+      takeg_t = q.take_indexed_slices_grad(1)
+
+      def apply_indexed_slices_grad(accum_op):
+        sess.run(accum_op)
+
+      threads = [
+          self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
+          for o in accum_ops
+      ]
+
+      for thread in threads:
+        thread.start()
+      for thread in threads:
+        thread.join()
+
+      val = sess.run(takeg_t)
+
+      expected_val = 550.0
+      self._assertEqual_nparray(
+          np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
+          val, sess)
+
+  def testParallelTakeGrad(self):
+    with self.cached_session() as sess:
+      q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
       elems = [e + 1 for e in range(10)]
       accum_ops = []
@@ -293,7 +362,7 @@
             np.array([[0, 0], [elems[i], 0]]), results[i], sess)
 
   def testAccumulatorApplyAndBlockingTake(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
 
@@ -328,7 +397,7 @@
       sess.run(takeg_op)
 
   def testAccumulatorCancel(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32,
           name="Q",
@@ -347,7 +416,7 @@
       takeg_thread.join()
 
   def testNonVectorIndices(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
 
@@ -359,7 +428,7 @@
             grad_values=np.array([1, 2]).astype(np.float32)).run()
 
   def testZeroDimensionValues(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
 
@@ -369,7 +438,7 @@
             grad_indices=[0], grad_values=np.array(1).astype(np.float32)).run()
 
   def testWrongNonEmptyInputValues(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
 
@@ -380,7 +449,7 @@
             grad_values=np.array([[0, 1, 1]]).astype(np.float32)).run()
 
   def testDynamicNonVectorIndices(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
 
@@ -399,7 +468,7 @@
                  })
 
   def testDynamicWrongNonEmptyInputValues(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
 
@@ -417,7 +486,7 @@
                  })
 
   def testEmptyShapeApply(self):
-    with self.test_session():
+    with self.cached_session():
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([]))
 
@@ -442,7 +511,7 @@
       q.apply_grad(grad_indices=[0], grad_values=[1.0]).run()
 
   def testValidateShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=[2, 2, None])
 
@@ -537,7 +606,7 @@
             local_step=1).run()
 
   def testReturnShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=[2, None])
 
@@ -562,7 +631,7 @@
       self.assertAllEqual(val.dense_shape, [-1, 2, 2, 3])
 
   def testApplyGradtInt32IndicesAndShape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q = data_flow_ops.SparseConditionalAccumulator(
           dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
       accum_op = q.apply_grad(
diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
index ca7898d..6e0714d 100644
--- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
@@ -42,7 +42,7 @@
         'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
         'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_dense(self):
@@ -62,7 +62,7 @@
         'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
         'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_mixed_string_sparse(self):
@@ -76,7 +76,7 @@
         '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
         '55555_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_mixed_string_dense(self):
@@ -94,7 +94,7 @@
         '55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
         '999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_sparse_cross_dense(self):
@@ -111,7 +111,7 @@
             'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
             'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
         ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_integer_sparse_input(self):
@@ -127,7 +127,7 @@
             '333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
             '5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
         ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_permutation_3x3x3(self):
@@ -169,7 +169,7 @@
         'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
         'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_permutation_3x1x2(self):
@@ -188,7 +188,7 @@
         'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
         'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
     ]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_large_batch(self):
@@ -221,7 +221,7 @@
       ])
 
     expected_out = self._sparse_tensor(col_out)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_one_column_empty(self):
@@ -234,7 +234,7 @@
         self._sparse_tensor([], 1),
         self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_empty(sess.run(op))
 
   def test_some_columns_empty(self):
@@ -253,7 +253,7 @@
         'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
         'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
     ]], 2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_all_columns_empty(self):
@@ -266,7 +266,7 @@
         self._sparse_tensor([]),
         self._sparse_tensor([])
     ])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_empty(sess.run(op))
 
   def test_hashed_zero_bucket_no_hash_key(self):
@@ -277,7 +277,7 @@
     ])
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[1971693436396284976]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed_zero_bucket(self):
@@ -290,7 +290,7 @@
         hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[4847552627144134031]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   # TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -304,7 +304,7 @@
         num_buckets=100)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[83]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed_output(self):
@@ -318,7 +318,7 @@
         hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
     # Check actual hashed output to prevent unintentional hashing changes.
     expected_out = self._sparse_tensor([[31]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self._assert_sparse_tensor_equals(expected_out, sess.run(op))
 
   def test_hashed__has_no_collision(self):
@@ -344,7 +344,7 @@
             self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
         ],
         num_buckets=1000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       out = sess.run(op)
       self.assertEqual(6, len(out.values))
       self.assertAllEqual([[0, i] for i in range(6)], out.indices)
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index f50e39d..90009fc 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -130,7 +130,7 @@
 
   def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta,
                      name):
-    with self.test_session():
+    with self.cached_session():
       a = constant_op.constant(
           RandMatrix(
               3, 2, tr_a, round_bfloat=True), dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index fc39de1..79efee3 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -628,7 +628,7 @@
         else:
           np_ans = np.max(np_ans, axis=ra, keepdims=keep_dims)
 
-    with self.test_session():
+    with self.cached_session():
       if do_sum:
         tf_dense_ans = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes,
                                                     keep_dims)
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index 87a4eb9..c71746c 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -81,7 +81,7 @@
     self.assertAllClose(np_ans, tf_ans)
 
   def testZeroDefault(self):
-    with self.test_session():
+    with self.cached_session():
       x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
       self.assertAllEqual(x, [0, 0, 7, 0])
 
@@ -94,12 +94,12 @@
     self.assertAllClose(np_ans, tf_ans)
 
   def testBadShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
         _SparseToDense([1, 3], [[5], [3]], 1, -1)
 
   def testBadValue(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense([1, 3], [5], [[5], [3]], -1)
       with self.assertRaisesOpError(
           r"sparse_values has incorrect shape \[2,1\], "
@@ -107,20 +107,20 @@
         dense.eval()
 
   def testBadNumValues(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1)
       with self.assertRaisesOpError(
           r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
         dense.eval()
 
   def testBadDefault(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense([1, 3], [5], [1, 2], [0])
       with self.assertRaisesOpError("default_value should be a scalar"):
         dense.eval()
 
   def testOutOfBoundsIndicesWithWithoutValidation(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense(
           sparse_indices=[[1], [10]],
           output_size=[5],
@@ -140,7 +140,7 @@
         dense_without_validation.eval()
 
   def testRepeatingIndicesWithWithoutValidation(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense(
           sparse_indices=[[1], [1]],
           output_size=[5],
@@ -158,7 +158,7 @@
       dense_without_validation.eval()
 
   def testUnsortedIndicesWithWithoutValidation(self):
-    with self.test_session():
+    with self.cached_session():
       dense = _SparseToDense(
           sparse_indices=[[2], [1]],
           output_size=[5],
diff --git a/tensorflow/python/kernel_tests/sparsemask_op_test.py b/tensorflow/python/kernel_tests/sparsemask_op_test.py
index cf6c949..6f5dd45 100644
--- a/tensorflow/python/kernel_tests/sparsemask_op_test.py
+++ b/tensorflow/python/kernel_tests/sparsemask_op_test.py
@@ -34,7 +34,7 @@
     out_values = values[1:, :]
     out_indices = np.array([2, 3, 4], dtype=np.int32)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       values_tensor = ops.convert_to_tensor(values)
       indices_tensor = ops.convert_to_tensor(indices)
       mask_indices_tensor = ops.convert_to_tensor(mask_indices)
diff --git a/tensorflow/python/kernel_tests/string_join_op_test.py b/tensorflow/python/kernel_tests/string_join_op_test.py
index ce19333..e4371ab 100644
--- a/tensorflow/python/kernel_tests/string_join_op_test.py
+++ b/tensorflow/python/kernel_tests/string_join_op_test.py
@@ -28,7 +28,7 @@
     input1 = "a"
     input2 = [["b"], ["c"]]
 
-    with self.test_session():
+    with self.cached_session():
       output = string_ops.string_join([input0, input1])
       self.assertAllEqual(output.eval(), [b"aa", b"ba"])
 
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 075a320..9f013c2 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -27,7 +27,7 @@
   def testStringLength(self):
     strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       lengths = string_ops.string_length(strings)
       values = sess.run(lengths)
       self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index b6a0f45..b968e88 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -32,7 +32,7 @@
   def testStringSplit(self):
     strings = ["pigs on the wing", "animals"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -42,7 +42,7 @@
   def testStringSplitEmptyDelimiter(self):
     strings = ["hello", "hola", b"\xF0\x9F\x98\x8E"]  # Last string is U+1F60E
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings, delimiter="")
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
@@ -60,7 +60,7 @@
   def testStringSplitEmptyToken(self):
     strings = ["", " a", "b ", " c", " ", " d ", "  e", "f  ", "  g  ", "  "]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(
@@ -72,7 +72,7 @@
   def testStringSplitOnSetEmptyToken(self):
     strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings, delimiter=" .")
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(
@@ -84,7 +84,7 @@
   def testStringSplitWithDelimiter(self):
     strings = ["hello|world", "hello world"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertRaises(
           ValueError, string_ops.string_split, strings, delimiter=["|", ""])
 
@@ -106,7 +106,7 @@
   def testStringSplitWithDelimiterTensor(self):
     strings = ["hello|world", "hello world"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       delimiter = array_ops.placeholder(dtypes.string)
 
       tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -124,7 +124,7 @@
   def testStringSplitWithDelimitersTensor(self):
     strings = ["hello.cruel,world", "hello cruel world"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       delimiter = array_ops.placeholder(dtypes.string)
 
       tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -143,7 +143,7 @@
   def testStringSplitWithNoSkipEmpty(self):
     strings = ["#a", "b#", "#c#"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings, "#", skip_empty=False)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -152,7 +152,7 @@
       self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
       self.assertAllEqual(shape, [3, 3])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split(strings, "#")
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(values, [b"a", b"b", b"c"])
@@ -165,7 +165,7 @@
   def testSplitV2(self):
     strings = ["pigs on the wing", "animals"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -180,7 +180,7 @@
     # ['', '', '4', '5', '', '6', '']
     strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings, sep="<>")
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(
@@ -198,7 +198,7 @@
     # ['1', '2', '', '3', '']
     strings = ["1,2,3", "4,5,,6,"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings, sep=',')
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -215,7 +215,7 @@
     #['1', '2', '3']
     strings = ["1 2 3", "  4  5    6  "]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -231,7 +231,7 @@
     # ['4', '5,,6,']
     strings = ["1,2,3", "4,5,,6,"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -247,7 +247,7 @@
     # ['4', '5    6  ']
     strings = ["1 2 3", "  4  5    6  "]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tokens = string_ops.string_split_v2(strings, maxsplit=1)
       indices, values, shape = sess.run(tokens)
       self.assertAllEqual(indices, [[0, 0], [0, 1],
diff --git a/tensorflow/python/kernel_tests/string_strip_op_test.py b/tensorflow/python/kernel_tests/string_strip_op_test.py
index 30fd477..a96b714 100644
--- a/tensorflow/python/kernel_tests/string_strip_op_test.py
+++ b/tensorflow/python/kernel_tests/string_strip_op_test.py
@@ -28,7 +28,7 @@
   def test_string_strip(self):
     strings = ["pigs on the wing", "animals"]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = string_ops.string_strip(strings)
       output = sess.run(output)
       self.assertAllEqual(output, [b"pigs on the wing", b"animals"])
@@ -37,7 +37,7 @@
     strings = [["pigs on the wing", "animals"],
                [" hello ", "\n\tworld \r \n"]]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = string_ops.string_strip(strings)
       output = sess.run(output)
       self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
@@ -46,7 +46,7 @@
   def test_string_strip_with_empty_strings(self):
     strings = [" hello ", "", "world ", " \t \r \n "]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output = string_ops.string_strip(strings)
       output = sess.run(output)
       self.assertAllEqual(output, [b"hello", b"", b"world", b""])
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 2c6064e..9cb0c9d 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -27,7 +27,7 @@
 class StringToHashBucketOpTest(test.TestCase):
 
   def testStringToOneHashBucketFast(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = array_ops.placeholder(dtypes.string)
       output = string_ops.string_to_hash_bucket_fast(input_string, 1)
       result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -35,7 +35,7 @@
       self.assertAllEqual([0, 0, 0], result)
 
   def testStringToHashBucketsFast(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = array_ops.placeholder(dtypes.string)
       output = string_ops.string_to_hash_bucket_fast(input_string, 10)
       result = output.eval(feed_dict={input_string: ['a', 'b', 'c', 'd']})
@@ -47,7 +47,7 @@
       self.assertAllEqual([9, 2, 2, 5], result)
 
   def testStringToOneHashBucketLegacyHash(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = array_ops.placeholder(dtypes.string)
       output = string_ops.string_to_hash_bucket(input_string, 1)
       result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -55,7 +55,7 @@
       self.assertAllEqual([0, 0, 0], result)
 
   def testStringToHashBucketsLegacyHash(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = array_ops.placeholder(dtypes.string)
       output = string_ops.string_to_hash_bucket(input_string, 10)
       result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -66,14 +66,14 @@
       self.assertAllEqual([8, 0, 7], result)
 
   def testStringToOneHashBucketStrongOneHashBucket(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = constant_op.constant(['a', 'b', 'c'])
       output = string_ops.string_to_hash_bucket_strong(
           input_string, 1, key=[123, 345])
       self.assertAllEqual([0, 0, 0], output.eval())
 
   def testStringToHashBucketsStrong(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = constant_op.constant(['a', 'b', 'c'])
       output = string_ops.string_to_hash_bucket_strong(
           input_string, 10, key=[98765, 132])
@@ -84,7 +84,7 @@
       self.assertAllEqual([4, 2, 8], output.eval())
 
   def testStringToHashBucketsStrongInvalidKey(self):
-    with self.test_session():
+    with self.cached_session():
       input_string = constant_op.constant(['a', 'b', 'c'])
       with self.assertRaisesOpError('Key must have 2 elements'):
         string_ops.string_to_hash_bucket_strong(
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
index cc4c21b..99ee25e 100644
--- a/tensorflow/python/kernel_tests/string_to_number_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -29,7 +29,7 @@
 class StringToNumberOpTest(test.TestCase):
 
   def _test(self, tf_type, good_pairs, bad_pairs):
-    with self.test_session():
+    with self.cached_session():
       # Build a small testing graph.
       input_string = array_ops.placeholder(dtypes.string)
       output = parsing_ops.string_to_number(
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 73ac71e..4d163a0 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.framework import errors_impl
@@ -25,7 +26,7 @@
 from tensorflow.python.platform import test
 
 
-class SubstrOpTest(test.TestCase):
+class SubstrOpTest(test.TestCase, parameterized.TestCase):
 
   def _testScalarString(self, dtype):
     test_string = b"Hello"
@@ -34,11 +35,22 @@
     expected_value = b"ell"
 
     substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      substr = substr_op.eval()
+      self.assertAllEqual(substr, expected_value)
+
+    # Negative position.
+    test_string = b"Hello"
+    position = np.array(-4, dtype)
+    length = np.array(3, dtype)
+    expected_value = b"ell"
+
+    substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
-    # position is equal to the length of string.
+    # Position is equal to the length of string.
     test_string = b""
     position = np.array(0, dtype)
     length = np.array(2, dtype)
@@ -49,6 +61,17 @@
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
+    # Negative position magnitude is equal to the length of string.
+    test_string = b"yo"
+    position = np.array(-2, dtype)
+    length = np.array(1, dtype)
+    expected_value = b"y"
+
+    substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      substr = substr_op.eval()
+      self.assertAllEqual(substr, expected_value)
+
   def _testVectorStrings(self, dtype):
     test_string = [b"Hello", b"World"]
     position = np.array(1, dtype)
@@ -60,6 +83,17 @@
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
+    # Negative position.
+    test_string = [b"Hello", b"World"]
+    position = np.array(-4, dtype)
+    length = np.array(3, dtype)
+    expected_value = [b"ell", b"orl"]
+
+    substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      substr = substr_op.eval()
+      self.assertAllEqual(substr, expected_value)
+
   def _testMatrixStrings(self, dtype):
     test_string = [[b"ten", b"eleven", b"twelve"],
                    [b"thirteen", b"fourteen", b"fifteen"],
@@ -74,17 +108,31 @@
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
+    # Negative position
+    test_string = [[b"ten", b"eleven", b"twelve"],
+                   [b"thirteen", b"fourteen", b"fifteen"],
+                   [b"sixteen", b"seventeen", b"eighteen"]]
+    position = np.array(-2, dtype)
+    length = np.array(2, dtype)
+    expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
+                      [b"en", b"en", b"en"]]
+
+    substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      substr = substr_op.eval()
+      self.assertAllEqual(substr, expected_value)
+
   def _testElementWisePosLen(self, dtype):
     test_string = [[b"ten", b"eleven", b"twelve"],
                    [b"thirteen", b"fourteen", b"fifteen"],
                    [b"sixteen", b"seventeen", b"eighteen"]]
-    position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
-    length = np.array([[2, 3, 4], [4, 3, 2], [5, 5, 5]], dtype)
-    expected_value = [[b"en", b"eve", b"lve"], [b"hirt", b"urt", b"te"],
-                      [b"ixtee", b"vente", b"hteen"]]
+    position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
+    length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
+    expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+                      [b"xteen", b"vente", b"hteen"]]
 
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -94,33 +142,33 @@
                    [b"thirteen", b"fourteen", b"fifteen"],
                    [b"sixteen", b"seventeen", b"eighteen"],
                    [b"nineteen", b"twenty", b"twentyone"]]
-    position = np.array([1, 2, 3], dtype)
+    position = np.array([1, -4, 3], dtype)
     length = np.array([1, 2, 3], dtype)
-    expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
-                      [b"i", b"ve", b"hte"], [b"i", b"en", b"nty"]]
+    expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+                      [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
     # Broadcast input string onto pos/len
     test_string = [b"thirteen", b"fourteen", b"fifteen"]
-    position = np.array([[1, 2, 3], [3, 2, 1], [5, 5, 5]], dtype)
+    position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
     length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
-    expected_value = [[b"hir", b"ur", b"t"], [b"r", b"ur", b"ift"],
-                      [b"ee", b"ee", b"en"]]
+    expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+                      [b"ee", b"ee", b"ft"]]
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
     # Test 1D broadcast
     test_string = b"thirteen"
-    position = np.array([1, 5, 7], dtype)
+    position = np.array([1, -5, 7], dtype)
     length = np.array([3, 2, 1], dtype)
-    expected_value = [b"hir", b"ee", b"n"]
+    expected_value = [b"hir", b"rt", b"n"]
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       substr = substr_op.eval()
       self.assertAllEqual(substr, expected_value)
 
@@ -128,10 +176,8 @@
     test_string = [[b"ten", b"eleven", b"twelve"],
                    [b"thirteen", b"fourteen", b"fifteen"],
                    [b"sixteen", b"seventeen", b"eighteen"]]
-    position = np.array([1, 2, 3, 4], dtype)
+    position = np.array([1, 2, -3, 4], dtype)
     length = np.array([1, 2, 3, 4], dtype)
-    expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
-                      [b"i", b"ve", b"hte"]]
     with self.assertRaises(ValueError):
       substr_op = string_ops.substr(test_string, position, length)
 
@@ -141,6 +187,15 @@
     position = np.array(7, dtype)
     length = np.array(3, dtype)
     substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        substr = substr_op.eval()
+
+    # Scalar/Scalar (with negative)
+    test_string = b"Hello"
+    position = np.array(-7, dtype)
+    length = np.array(3, dtype)
+    substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
@@ -150,16 +205,16 @@
     position = np.array(4, dtype)
     length = np.array(1, dtype)
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
 
-    # Negative pos
-    test_string = b"Hello"
-    position = np.array(-1, dtype)
-    length = np.array(3, dtype)
+    # Vector/Scalar (with negative)
+    test_string = [b"good", b"good", b"bad", b"good"]
+    position = np.array(-4, dtype)
+    length = np.array(1, dtype)
     substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
 
@@ -169,6 +224,16 @@
     position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
     length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
     substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        substr = substr_op.eval()
+
+    # Matrix/Matrix (with negative)
+    test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+                   [b"good", b"good", b"good"]]
+    position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
+    length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
+    substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
@@ -178,6 +243,15 @@
     position = np.array([1, 2, 4], dtype)
     length = np.array([1, 2, 3], dtype)
     substr_op = string_ops.substr(test_string, position, length)
+    with self.cached_session():
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        substr = substr_op.eval()
+
+    # Broadcast (with negative)
+    test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+    position = np.array([-1, -2, -4], dtype)
+    length = np.array([1, 2, 3], dtype)
+    substr_op = string_ops.substr(test_string, position, length)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(errors_impl.InvalidArgumentError):
         substr = substr_op.eval()
@@ -198,7 +272,18 @@
     with self.assertRaises(ValueError):
       substr_op = string_ops.substr(test_string, position, length)
 
-  def _testAll(self, dtype):
+    # Negative position.
+    test_string = [[b"ten", b"eleven", b"twelve"],
+                   [b"thirteen", b"fourteen", b"fifteen"],
+                   [b"sixteen", b"seventeen", b"eighteen"]]
+    position = np.array([[-1, -2, -3]], dtype)
+    length = np.array([1, 2, 3], dtype)
+    # Should fail: position/length have different rank
+    with self.assertRaises(ValueError):
+      substr_op = string_ops.substr(test_string, position, length)
+
+  @parameterized.parameters(np.int32, np.int64)
+  def testAll(self, dtype):
     self._testScalarString(dtype)
     self._testVectorStrings(dtype)
     self._testMatrixStrings(dtype)
@@ -208,14 +293,8 @@
     self._testOutOfRangeError(dtype)
     self._testMismatchPosLenShapes(dtype)
 
-  def testInt32(self):
-    self._testAll(np.int32)
-
-  def testInt64(self):
-    self._testAll(np.int64)
-
   def testWrongDtype(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(TypeError):
         string_ops.substr(b"test", 3.0, 1)
       with self.assertRaises(TypeError):
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 2da7107f..0c50012 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -34,7 +34,7 @@
     return summ
 
   def testScalarSummary(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant([10.0, 20.0])
       summ = logging_ops.scalar_summary(["c1", "c2"], const, name="mysumm")
       value = sess.run(summ)
@@ -45,7 +45,7 @@
       """, self._AsSummary(value))
 
   def testScalarSummaryDefaultName(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant([10.0, 20.0])
       summ = logging_ops.scalar_summary(["c1", "c2"], const)
       value = sess.run(summ)
@@ -56,7 +56,7 @@
       """, self._AsSummary(value))
 
   def testMergeSummary(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant(10.0)
       summ1 = summary.histogram("h", const)
       summ2 = logging_ops.scalar_summary("c", const)
diff --git a/tensorflow/python/kernel_tests/summary_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
index d534aad..0f46433 100644
--- a/tensorflow/python/kernel_tests/summary_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
@@ -42,7 +42,7 @@
     self.assertTrue(np.array_equal(actual, expected))
 
   def testTags(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(1)
       s1 = summary_ops.tensor_summary("s1", c)
       with ops.name_scope("foo"):
@@ -65,7 +65,7 @@
     self.assertEqual(v4.tag, "foo/zod/TensorSummary")
 
   def testScalarSummary(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant(10.0)
       summ = summary_ops.tensor_summary("foo", const)
       result = sess.run(summ)
@@ -76,7 +76,7 @@
 
   def testStringSummary(self):
     s = six.b("foobar")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant(s)
       summ = summary_ops.tensor_summary("foo", const)
       result = sess.run(summ)
@@ -86,7 +86,7 @@
     self._AssertNumpyEq(n, s)
 
   def testManyScalarSummary(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = array_ops.ones([5, 5, 5])
       summ = summary_ops.tensor_summary("foo", const)
       result = sess.run(summ)
@@ -96,7 +96,7 @@
 
   def testManyStringSummary(self):
     strings = [[six.b("foo bar"), six.b("baz")], [six.b("zoink"), six.b("zod")]]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant(strings)
       summ = summary_ops.tensor_summary("foo", const)
       result = sess.run(summ)
@@ -106,7 +106,7 @@
 
   def testManyBools(self):
     bools = [True, True, True, False, False, False]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       const = constant_op.constant(bools)
       summ = summary_ops.tensor_summary("foo", const)
       result = sess.run(summ)
@@ -116,7 +116,7 @@
     self._AssertNumpyEq(n, bools)
 
   def testSummaryDescriptionAndDisplayName(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       def get_description(summary_op):
         summ_str = sess.run(summary_op)
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index 8ad29af..d8d7644 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -48,7 +48,7 @@
     with self.assertRaises(ValueError):
       math_ops.tensordot(a, b, (a_axes, b_axes))
     # Invalid dynamic shapes.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                    "Matrix size-incompatible"):
         a_ph = array_ops.placeholder(dtypes.float32)
@@ -80,7 +80,7 @@
     output = math_ops.tensordot(a_ph, b_ph, axes_ph)
     # Note: We don't support scalar Tensor values for axes.
     for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         with self.assertRaises(errors_impl.InvalidArgumentError):
           _ = sess.run(
               [output], feed_dict={
@@ -92,7 +92,7 @@
   # Test case for 11950
   def test_valid_axis(self):
     for axes_value in [1, 2], [[1], [2]], [[], []], 0:
-      with self.test_session() as sess:
+      with self.cached_session():
         np_a = np.ones((3, 3))
         np_b = np.array([2, 3, 1])[None, None]
         np_ans = np.tensordot(np_a, np_b, axes_value)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 290200c..f428002 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -451,13 +451,13 @@
         array_ops.transpose(array_ops.placeholder(dtypes.int32)).get_shape())
 
   def testNullTensor(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([], dtype=dtypes.float32, shape=[1, 4, 0])
       xt = array_ops.transpose(x, [0, 2, 1]).eval()
       self.assertAllEqual(xt.shape, (1, 0, 4))
 
   def _testError(self, x, p, err):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError(err):
         array_ops.transpose(x, p).eval()
 
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index bbc040d..316570e 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -30,7 +30,7 @@
 
   def testInt32(self):
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx = array_ops.unique(x)
       tf_y, tf_idx = sess.run([y, idx])
 
@@ -41,7 +41,7 @@
 
   def testInt32OutIdxInt64(self):
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx = array_ops.unique(x, out_idx=dtypes.int64)
       tf_y, tf_idx = sess.run([y, idx])
 
@@ -53,7 +53,7 @@
   def testString(self):
     indx = np.random.randint(65, high=122, size=7000)
     x = [chr(i) for i in indx]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx = array_ops.unique(x)
       tf_y, tf_idx = sess.run([y, idx])
 
@@ -65,7 +65,7 @@
   def testInt32Axis(self):
     for dtype in [np.int32, np.int64]:
       x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
         tf_y0, tf_idx0 = sess.run([y0, idx0])
         y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
@@ -79,7 +79,7 @@
     # This test is only temporary, once V2 is used
     # by default, the axis will be wrapped to allow `axis=None`.
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32))
       tf_y, tf_idx = sess.run([y, idx])
 
@@ -93,7 +93,7 @@
 
   def testInt32(self):
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx, count = array_ops.unique_with_counts(x)
       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
 
@@ -106,7 +106,7 @@
 
   def testInt32OutIdxInt64(self):
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64)
       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
 
@@ -121,7 +121,7 @@
     indx = np.random.randint(65, high=122, size=7000)
     x = [chr(i) for i in indx]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx, count = array_ops.unique_with_counts(x)
       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
 
@@ -136,7 +136,7 @@
   def testInt32Axis(self):
     for dtype in [np.int32, np.int64]:
       x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
             x, axis=np.array([0], dtype))
         tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
@@ -154,7 +154,7 @@
     # This test is only temporary, once V2 is used
     # by default, the axis will be wrapped to allow `axis=None`.
     x = np.random.randint(2, high=10, size=7000)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       y, idx, count = gen_array_ops.unique_with_counts_v2(
           x, axis=np.array([], np.int32))
       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 1ee6e08..b373c41 100644
--- a/tensorflow/python/kernel_tests/unstack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -99,7 +99,7 @@
           self.assertLess(err, 1e-6)
 
   def testInferNum(self):
-    with self.test_session():
+    with self.cached_session():
       for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
         x = array_ops.placeholder(np.float32, shape=shape)
         cs = array_ops.unstack(x)
@@ -131,13 +131,13 @@
       for j in range(-i, i):
         expected = np_split_squeeze(a, j)
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           actual_unstack = sess.run(array_ops.unstack(a, axis=j))
 
         self.assertAllEqual(expected, actual_unstack)
 
   def testAxis0Default(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
       unstacked = sess.run(array_ops.unstack(a))
 
@@ -156,7 +156,7 @@
       array_ops.unstack(a, axis=-3)
 
   def testZeroLengthDim(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.zeros(shape=(0, 1, 2))
       y = array_ops.unstack(x, axis=1)[0].eval()
       self.assertEqual(y.shape, (0, 2))
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
index cf369c0..3d2f8b6 100644
--- a/tensorflow/python/kernel_tests/variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -118,7 +118,7 @@
     self.assertEqual(tensor_shape.unknown_shape(), assigned.get_shape())
 
   def testAssignNoShape(self):
-    with self.test_session():
+    with self.cached_session():
       value = self._NewShapelessTensor()
       var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
       self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
@@ -126,7 +126,7 @@
                        state_ops.assign(var, value).get_shape())
 
   def testAssignNoShapeNoValidateShape(self):
-    with self.test_session():
+    with self.cached_session():
       value = self._NewShapelessTensor()
       var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
       self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index d57b79c..401e1ae 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -113,7 +113,7 @@
         self.assertEqual(w.constraint, constraint)
 
   def testStringDefaultInitializer(self):
-    with self.test_session():
+    with self.cached_session():
       v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
       variables_lib.global_variables_initializer().run()
       self.assertAllEqual(compat.as_bytes(v.eval()), b"")
@@ -263,7 +263,7 @@
 
   # TODO(alive): support variable partitioning/caching in eager mode.
   def testVarScopeCachingDevice(self):
-    with self.test_session():
+    with self.cached_session():
       caching_device = "/job:moo"
       with variable_scope.variable_scope("tower"):
         with variable_scope.variable_scope(
@@ -367,7 +367,7 @@
       variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
 
   def testControlDeps(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variable_scope.get_variable(
           "v0", [1], initializer=init_ops.constant_initializer(0))
       with ops.control_dependencies([v0.value()]):
@@ -403,7 +403,7 @@
       variable_scope._DEFAULT_USE_RESOURCE = old
 
   def testControlFlow(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variable_scope.get_variable(
           "v0", [], initializer=init_ops.constant_initializer(0))
       var_dict = {}
@@ -513,7 +513,7 @@
           self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
 
   def testVarScopeOriginalNameScope(self):
-    with self.test_session():
+    with self.cached_session():
       with ops.name_scope("scope1"):
         with variable_scope.variable_scope("tower") as tower:
           self.assertEqual(tower.original_name_scope, "scope1/tower/")
@@ -536,7 +536,7 @@
               self.assertEqual(sc3, "scope1/tower/bar_1/")
 
   def testVarScopeObjectReuse(self):
-    with self.test_session():
+    with self.cached_session():
       vs = None
       with variable_scope.variable_scope("jump", reuse=True) as scope:
         vs = scope
@@ -563,7 +563,7 @@
         self.assertFalse(jump_no_reuse.reuse)
 
   def testVarScopeGetOrCreateReuse(self):
-    with self.test_session():
+    with self.cached_session():
 
       def test_value(value):
         x = constant_op.constant(value)
@@ -582,7 +582,7 @@
       test_value(17.)
 
   def testVarOpScope(self):
-    with self.test_session():
+    with self.cached_session():
       with ops.name_scope("testVarOpScope1"):
         with variable_scope.variable_scope("tower", "default", []):
           self.assertEqual(
@@ -608,7 +608,7 @@
             self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
 
   def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope(None, "defaultScope1"):
         with variable_scope.variable_scope(None, "layer"):
           self.assertEqual(
@@ -631,7 +631,7 @@
               "defaultScope1_2/layer/w:0")
 
   def testVarOpScopeUniqueNamesWithJump(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("default") as default:
         with variable_scope.variable_scope(None, "layer"):
           self.assertEqual(
@@ -647,7 +647,7 @@
               variable_scope.get_variable("w", []).name, "default/layer_2/w:0")
 
   def testVarOpScopeReuse(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         with variable_scope.variable_scope("tower", "default", []):
           self.assertEqual(
@@ -673,7 +673,7 @@
             self.assertEqual(sc2, "outer_1/default/scope2/")
 
   def testVarScopeGetVar(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("root"):
         with variable_scope.variable_scope("towerA") as tower_a:
           va = variable_scope.get_variable("v", [1])
@@ -719,7 +719,7 @@
         self.assertEqual("dtype" in str(exc.exception), True)
 
   def testVarScopeOuterScope(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         pass
       with variable_scope.variable_scope(outer):
@@ -743,7 +743,7 @@
             self.assertEqual(sc2, "outer_2/default/scope2/")
 
   def testVarScopeNestedOuterScope(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         with variable_scope.variable_scope(outer):
           self.assertEqual(
@@ -768,7 +768,7 @@
             self.assertEqual(sc2, "outer/default_1/scope2/")
 
   def testVarOpScopeReuseParam(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         with variable_scope.variable_scope("tower", "default", []):
           self.assertEqual(
@@ -795,14 +795,14 @@
             self.assertEqual(sc2, "outer_1/default/scope2/")
 
   def testVarOpScopeReuseError(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         with variable_scope.variable_scope(None, "default", reuse=True):
           self.assertEqual(
               variable_scope.get_variable("w", []).name, "outer/tower/w:0")
 
   def testVarOpScopeOuterScope(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         pass
       with variable_scope.variable_scope(outer, "default", []):
@@ -827,7 +827,7 @@
             self.assertEqual(sc2, "outer_2/default/scope2/")
 
   def testVarOpScopeNestedOuterScope(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer") as outer:
         with variable_scope.variable_scope(outer, "default", []):
           self.assertEqual(
@@ -851,7 +851,7 @@
             self.assertEqual(sc2, "outer_1/default/scope2/")
 
   def testBasicWhenAuxiliaryNameScopeIsFalse(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope(
           "scope", auxiliary_name_scope=False) as scope:
         self.assertEqual(scope.original_name_scope, "")
@@ -886,7 +886,7 @@
               constant_op.constant([], name="c").name, "outer/inner/c:0")
 
   def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope(
           None, default_name="default", auxiliary_name_scope=False) as scope:
         self.assertEqual(scope.original_name_scope, "")
@@ -910,7 +910,7 @@
               constant_op.constant([], name="c").name, "outer/default/c:0")
 
   def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
-    with self.test_session():
+    with self.cached_session():
       root_scope = variable_scope.get_variable_scope()
       with variable_scope.variable_scope(
           root_scope, auxiliary_name_scope=False) as scope:
@@ -927,7 +927,7 @@
               constant_op.constant([], name="c1").name, "outer/c1:0")
 
   def testAuxiliaryNameScopeIsInvalid(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
         with variable_scope.variable_scope(
             None, default_name="scope", auxiliary_name_scope="invalid"):
@@ -947,7 +947,7 @@
 
   def testReuseScopeWithoutNameScopeCollision(self):
     # Github issue: #13429
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("outer"):
         with variable_scope.variable_scope("inner") as inner:
           pass
@@ -1021,7 +1021,7 @@
     self.assertEqual(varname_type[1], ("y", dtypes.int64))
 
   def testGetCollection(self):
-    with self.test_session():
+    with self.cached_session():
       _ = variable_scope.get_variable("testGetCollection_a", [])
       _ = variable_scope.get_variable(
           "testGetCollection_b", [], trainable=False)
@@ -1075,7 +1075,7 @@
       ])
 
   def testGetTrainableVariablesWithGetVariable(self):
-    with self.test_session():
+    with self.cached_session():
       _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
       with variable_scope.variable_scope(
           "testGetTrainableVariables_foo") as scope:
@@ -1111,7 +1111,7 @@
             trainable=True)
 
   def testGetTrainableVariablesWithVariable(self):
-    with self.test_session():
+    with self.cached_session():
       _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
       with variable_scope.variable_scope(
           "testGetTrainableVariables_foo") as scope:
@@ -1150,7 +1150,7 @@
             trainable=True)
 
   def testGetGlobalVariables(self):
-    with self.test_session():
+    with self.cached_session():
       _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
       with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
         _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
@@ -1160,7 +1160,7 @@
              "testGetGlobalVariables_b:0"])
 
   def testGetLocalVariables(self):
-    with self.test_session():
+    with self.cached_session():
       _ = variable_scope.get_variable(
           "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
       with variable_scope.variable_scope("foo") as scope:
@@ -1396,7 +1396,7 @@
     self.assertEqual("scope/v/0:0", true_vars[0].name)
     self.assertEqual("scope/v/1:0", true_vars[1].name)
     self.assertEqual("custom_getter/add:0", v.name)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       np_vars, np_v = sess.run([true_vars, v])
       self.assertAllClose(np_v, sum(np_vars))
@@ -1436,7 +1436,7 @@
     self.assertEqual(template % (1, 1, 0), true_vars[6].name)
     self.assertEqual(template % (1, 1, 1), true_vars[7].name)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       variables_lib.global_variables_initializer().run()
       np_vars, np_v = sess.run([true_vars, v])
       # take products of sums of products
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2b9c62a..2e79756 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -42,7 +42,7 @@
 class VariablesTestCase(test.TestCase):
 
   def testInitialization(self):
-    with self.test_session():
+    with self.cached_session():
       var0 = variables.Variable(0.0)
       self.assertEqual("Variable:0", var0.name)
       self.assertEqual("Variable", var0._shared_name)
@@ -69,7 +69,7 @@
       self.assertAllClose(1.1, var1.eval())
 
   def testInitializationOrder(self):
-    with self.test_session():
+    with self.cached_session():
       rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
       self.assertEqual("rnd:0", rnd.name)
       self.assertEqual([3, 6], rnd.get_shape())
@@ -106,7 +106,7 @@
         pass
 
   def testAssignments(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(0.0)
       plus_one = var.assign_add(1.0)
       minus_one = var.assign_sub(2.0)
@@ -142,7 +142,7 @@
       self.assertAllClose(4.0, var.eval())
 
   def testZeroSizeStringAssign(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       array = variables.Variable(
           initial_value=array_ops.zeros((0,), dtype=dtypes.string),
           name="foo",
@@ -154,7 +154,7 @@
       self.assertEqual([], list(sess.run(copy_op)))
 
   def _countUpToTest(self, dtype):
-    with self.test_session():
+    with self.cached_session():
       zero = constant_op.constant(0, dtype=dtype)
       var = variables.Variable(zero)
       count_up_to = var.count_up_to(3)
@@ -186,7 +186,7 @@
     self._countUpToTest(dtypes.int64)
 
   def testControlDepsNone(self):
-    with self.test_session():
+    with self.cached_session():
       c = constant_op.constant(1.0)
       with ops.control_dependencies([c]):
         # d get the control dep.
@@ -199,7 +199,7 @@
       self.assertEqual([], var_x._ref().op.control_inputs)  # pylint: disable=protected-access
 
   def testControlFlow(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variables.Variable(0, name="v0")
       var_dict = {}
 
@@ -248,7 +248,7 @@
       control_flow_ops.while_loop(cond, body, [0, 0])
 
   def testUseVariableAsTensor(self):
-    with self.test_session():
+    with self.cached_session():
       var_x = variables.Variable(2.0)
       var_y = variables.Variable(3.0)
       variables.global_variables_initializer().run()
@@ -257,7 +257,7 @@
       self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval())
 
   def testZeroSizeVarSameAsConst(self):
-    with self.test_session():
+    with self.cached_session():
       zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
       zero_size_const = array_ops.ones([2, 0])
       variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
@@ -269,7 +269,7 @@
       self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
 
   def testCachingDevice(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(2.0)
       self.assertEqual(var.device, var.value().device)
       self.assertEqual(var.device, var.initialized_value().device)
@@ -279,7 +279,7 @@
       self.assertTrue(var_cached.value().device.startswith("/job:foo"))
 
   def testCollections(self):
-    with self.test_session():
+    with self.cached_session():
       var_x = variables.Variable(2.0)
       var_y = variables.Variable(2.0, trainable=False)
       var_z = variables.Variable(2.0, trainable=True)
@@ -294,7 +294,7 @@
       self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())
 
   def testCollectionsWithScope(self):
-    with self.test_session():
+    with self.cached_session():
       with ops.name_scope("scope_1"):
         var_x = variables.Variable(2.0)
       with ops.name_scope("scope_2"):
@@ -309,7 +309,7 @@
       self.assertEqual([var_y], variables.trainable_variables("scope_2"))
 
   def testOperators(self):
-    with self.test_session():
+    with self.cached_session():
       var_f = variables.Variable([2.0])
       add = var_f + 0.0
       radd = 1.0 + var_f
@@ -382,13 +382,13 @@
       self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval())
 
   def testSession(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var = variables.Variable([1, 12])
       variables.global_variables_initializer().run()
       self.assertAllClose([1, 12], sess.run(var))
 
   def testDevicePlacement(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with ops.device("/cpu:0"):
         var = variables.Variable([1, 12])
       init_value = var.initialized_value()
@@ -408,7 +408,7 @@
   def testInitializerFunction(self):
     value = [[-42], [133.7]]
     shape = [2, 1]
-    with self.test_session():
+    with self.cached_session():
       initializer = lambda: constant_op.constant(value)
 
       v1 = variables.Variable(initializer, dtype=dtypes.float32)
@@ -443,7 +443,7 @@
           constraint=constraint)
 
   def testNoRefDataRace(self):
-    with self.test_session():
+    with self.cached_session():
       a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
       b = variables.Variable(a.initialized_value() + 2)
       c = variables.Variable(b.initialized_value() + 2)
@@ -453,7 +453,7 @@
       self.assertAllEqual(c.eval(), [5, 6, 7])
 
   def testInitializerFunctionDevicePlacement(self):
-    with self.test_session():
+    with self.cached_session():
       initializer = lambda: constant_op.constant(42.0)
       with ops.device("/cpu:100"):
         v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
@@ -471,11 +471,11 @@
         self.assertEqual(expected_group_v2, i.op.colocation_groups())
 
   def testVariableDefInitializedInstances(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v_def = variables.Variable(
           initial_value=constant_op.constant(3.0)).to_proto()
 
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       # v describes a VariableDef-based variable without an initial value.
       v = variables.Variable(variable_def=v_def)
       self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -486,7 +486,7 @@
       self.assertEqual(1.0, v.initialized_value().eval())
 
     v_def.ClearField("initial_value_name")
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       # Restoring a legacy VariableDef proto that does not have
       # initial_value_name set should still work.
       v = variables.Variable(variable_def=v_def)
@@ -514,7 +514,7 @@
           .trainable)
 
   def testLoad(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(np.zeros((5, 5), np.float32))
       variables.global_variables_initializer().run()
       var.load(np.ones((5, 5), np.float32))
@@ -540,12 +540,12 @@
 class IsInitializedTest(test.TestCase):
 
   def testNoVars(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       uninited = variables.report_uninitialized_variables()
       self.assertEqual(0, sess.run(uninited).size)
 
   def testAssertVariablesInitialized(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([1, 2], name="v")
       w = variables.Variable([3, 4], name="w")
       _ = v, w
@@ -555,7 +555,7 @@
       self.assertEqual(0, sess.run(uninited).size)
 
   def testVariableList(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([1, 2], name="v")
       w = variables.Variable([3, 4], name="w")
       uninited = variables.report_uninitialized_variables()
@@ -566,14 +566,14 @@
       self.assertEqual(0, sess.run(uninited).size)
 
   def testZeroSizeVarInitialized(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable(array_ops.zeros([0, 2]), name="v")
       uninited = variables.report_uninitialized_variables()
       v.initializer.run()  # not strictly necessary
       self.assertEqual(0, sess.run(uninited).size)
 
   def testTrainingWithZeroSizeVar(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       a = variables.Variable(array_ops.zeros([0, 2]))
       b = variables.Variable(array_ops.ones([2, 2]))
       objective = math_ops.reduce_sum(b + math_ops.matmul(
@@ -592,7 +592,7 @@
       self.assertEqual(None, variables.assert_variables_initialized())
 
   def testVariables(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([1, 2])
       w = variables.Variable([3, 4])
       _ = v, w
@@ -603,7 +603,7 @@
       sess.run(inited)
 
   def testVariableList(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([1, 2])
       w = variables.Variable([3, 4])
       inited = variables.assert_variables_initialized([v])
diff --git a/tensorflow/python/kernel_tests/weights_broadcast_test.py b/tensorflow/python/kernel_tests/weights_broadcast_test.py
index eda2856e..85f9abc 100644
--- a/tensorflow/python/kernel_tests/weights_broadcast_test.py
+++ b/tensorflow/python/kernel_tests/weights_broadcast_test.py
@@ -44,7 +44,7 @@
     values_placeholder = array_ops.placeholder(dtypes_lib.float32)
     dynamic_op = weights_broadcast_ops.assert_broadcastable(
         weights=weights_placeholder, values=values_placeholder)
-    with self.test_session():
+    with self.cached_session():
       static_op.run()
       dynamic_op.run(feed_dict={
           weights_placeholder: weights,
@@ -100,7 +100,7 @@
     values_placeholder = array_ops.placeholder(dtypes_lib.float32)
     dynamic_op = weights_broadcast_ops.assert_broadcastable(
         weights=weights_placeholder, values=values_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
         dynamic_op.run(feed_dict={
             weights_placeholder: weights,
@@ -157,7 +157,7 @@
     values_placeholder = array_ops.placeholder(dtypes_lib.float32)
     dynamic_op = weights_broadcast_ops.broadcast_weights(
         weights=weights_placeholder, values=values_placeholder)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(expected, static_op.eval())
       self.assertAllEqual(expected, dynamic_op.eval(feed_dict={
           weights_placeholder: weights,
@@ -227,7 +227,7 @@
     values_placeholder = array_ops.placeholder(dtypes_lib.float32)
     dynamic_op = weights_broadcast_ops.broadcast_weights(
         weights=weights_placeholder, values=values_placeholder)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
         dynamic_op.eval(feed_dict={
             weights_placeholder: weights,
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 60c726d..7298851 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -153,13 +153,13 @@
       self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
 
   def testShapeMismatch(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         gen_nn_ops.softmax_cross_entropy_with_logits(
             [[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
 
   def testNotMatrix(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         gen_nn_ops.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
                                                      [0., 1., 0., 1.])
@@ -180,7 +180,7 @@
         np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
 
   def testGradient(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       l = constant_op.constant(
           [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
           shape=[3, 4],
@@ -207,7 +207,7 @@
     self.assertLess(err, 5e-8)
 
   def testGradientLabelWithV2(self):
-    with self.test_session():
+    with self.cached_session():
       l = constant_op.constant(
           [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
           shape=[3, 4],
@@ -225,7 +225,7 @@
     self.assertLess(err, 5e-8)
 
   def testSecondGradient(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       l = constant_op.constant(
           [
               0.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, 0.0, 0.0, 0.0, 0.0, 0.5 / 3, 0.0,
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 3b4f12a..269142a 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -55,6 +55,10 @@
   return PyIsInstance(obj, &PyDoubleArrType_Type);  // NumPy double type.
 }
 
+bool IsNumpyHalf(PyObject* obj) {
+  return PyIsInstance(obj, &PyHalfArrType_Type);
+}
+
 bool IsPyFloat(PyObject* obj) {
   return PyFloat_Check(obj) ||
          PyIsInstance(obj, &PyFloatingArrType_Type);  // NumPy float types
@@ -156,6 +160,8 @@
       }
     } else if (IsPyDouble(obj)) {
       *dtype = DT_DOUBLE;
+    } else if (IsNumpyHalf(obj)) {
+      *dtype = DT_HALF;
     } else if (IsPyFloat(obj)) {
       *dtype = DT_FLOAT;
     } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) {
@@ -357,6 +363,17 @@
 DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
 DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
 
+const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
+  // NOTE(nareshmodi): Is there a way to convert to C double without the
+  // intermediate Python double? This will help with ConvertOneFloat as well.
+  Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
+  double v_double = PyFloat_AS_DOUBLE(as_float.get());
+  *out = Eigen::half(v_double);
+
+  return nullptr;
+}
+DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf);
+
 // String support
 
 const char* ConvertOneString(PyObject* v, string* out) {
@@ -452,6 +469,9 @@
       if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
       break;
 
+    case DT_HALF:
+      RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
     case DT_INT64:
       if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
       break;
@@ -489,8 +509,13 @@
         // final type.
         RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
       }
+
     case DT_DOUBLE:
       RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
+
+    case DT_HALF:
+      RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
+
     case DT_INT64:
       if (requested_dtype == DT_INVALID) {
         const char* error = ConvertInt32(obj, shape, ret);
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 9500fc6..07ce071 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -30,6 +30,8 @@
 
 PyRecordReader::PyRecordReader() {}
 
+// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
+// RecordReaderOptions, if this changes the API can be updated at that time.
 PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
                                     const string& compression_type_string,
                                     TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index e4e5268..faf20df 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -28,7 +28,7 @@
 PyRecordWriter::PyRecordWriter() {}
 
 PyRecordWriter* PyRecordWriter::New(const string& filename,
-                                    const string& compression_type_string,
+                                    const io::RecordWriterOptions& options,
                                     TF_Status* out_status) {
   std::unique_ptr<WritableFile> file;
   Status s = Env::Default()->NewWritableFile(filename, &file);
@@ -38,10 +38,6 @@
   }
   PyRecordWriter* writer = new PyRecordWriter;
   writer->file_ = std::move(file);
-
-  RecordWriterOptions options =
-      RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
-
   writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
   return writer;
 }
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index 61a4960..9b0792c 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -20,6 +20,7 @@
 
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/record_writer.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -36,10 +37,8 @@
 // by multiple threads.
 class PyRecordWriter {
  public:
-  // TODO(vrv): make this take a shared proto to configure
-  // the compression options.
   static PyRecordWriter* New(const string& filename,
-                             const string& compression_type_string,
+                             const io::RecordWriterOptions& compression_options,
                              TF_Status* out_status);
   ~PyRecordWriter();
 
diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i
index 3181c9a..b2c2bda 100644
--- a/tensorflow/python/lib/io/py_record_writer.i
+++ b/tensorflow/python/lib/io/py_record_writer.i
@@ -18,6 +18,11 @@
 %include "tensorflow/python/platform/base.i"
 %include "tensorflow/python/lib/core/strings.i"
 
+// Define int8_t explicitly instead of including "stdint.i", since "stdint.h"
+// and "stdint.i" disagree on the definition of int64_t.
+typedef signed char int8;
+%{ typedef signed char int8; %}
+
 %feature("except") tensorflow::io::PyRecordWriter::New {
   // Let other threads run while we write
   Py_BEGIN_ALLOW_THREADS
@@ -26,6 +31,7 @@
 }
 
 %newobject tensorflow::io::PyRecordWriter::New;
+%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
 
 %feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
   // Let other threads run while we write
@@ -35,6 +41,8 @@
 }
 
 %{
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
 #include "tensorflow/python/lib/io/py_record_writer.h"
 %}
 
@@ -48,7 +56,21 @@
 %unignore tensorflow::io::PyRecordWriter::Flush;
 %unignore tensorflow::io::PyRecordWriter::Close;
 %unignore tensorflow::io::PyRecordWriter::New;
+%unignore tensorflow::io::ZlibCompressionOptions;
+%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
+%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
+%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
+%unignore tensorflow::io::RecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::zlib_options;
 
+%include "tensorflow/core/lib/io/record_writer.h"
+%include "tensorflow/core/lib/io/zlib_compression_options.h"
 %include "tensorflow/python/lib/io/py_record_writer.h"
 
 %unignoreall
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 2b3e986..cce71a2 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -33,8 +33,6 @@
   GZIP = 2
 
 
-# NOTE(vrv): This will eventually be converted into a proto.  to match
-# the interface used by the C++ RecordWriter.
 @tf_export("python_io.TFRecordOptions")
 class TFRecordOptions(object):
   """Options used for manipulating TFRecord files."""
@@ -44,14 +42,105 @@
       TFRecordCompressionType.NONE: ""
   }
 
-  def __init__(self, compression_type):
+  def __init__(self,
+               compression_type=None,
+               flush_mode=None,
+               input_buffer_size=None,
+               output_buffer_size=None,
+               window_bits=None,
+               compression_level=None,
+               compression_method=None,
+               mem_level=None,
+               compression_strategy=None):
+    # pylint: disable=line-too-long
+    """Creates a `TFRecordOptions` instance.
+
+    Options only effect TFRecordWriter when compression_type is not `None`.
+    Documentation, details, and defaults can be found in
+    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
+    and in the [zlib manual](http://www.zlib.net/manual.html).
+    Leaving an option as `None` allows C++ to set a reasonable default.
+
+    Args:
+      compression_type: `TFRecordCompressionType` or `None`.
+      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
+      input_buffer_size: int or `None`.
+      output_buffer_size: int or `None`.
+      window_bits: int or `None`.
+      compression_level: 0 to 9, or `None`.
+      compression_method: compression method or `None`.
+      mem_level: 1 to 9, or `None`.
+      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.
+
+    Returns:
+      A `TFRecordOptions` object.
+
+    Raises:
+      ValueError: If compression_type is invalid.
+    """
+    # pylint: enable=line-too-long
+    # Check compression_type is valid, but for backwards compatibility don't
+    # immediately convert to a string.
+    self.get_compression_type_string(compression_type)
     self.compression_type = compression_type
+    self.flush_mode = flush_mode
+    self.input_buffer_size = input_buffer_size
+    self.output_buffer_size = output_buffer_size
+    self.window_bits = window_bits
+    self.compression_level = compression_level
+    self.compression_method = compression_method
+    self.mem_level = mem_level
+    self.compression_strategy = compression_strategy
 
   @classmethod
   def get_compression_type_string(cls, options):
+    """Convert various option types to a unified string.
+
+    Args:
+      options: `TFRecordOption`, `TFRecordCompressionType`, or string.
+
+    Returns:
+      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).
+
+    Raises:
+      ValueError: If compression_type is invalid.
+    """
     if not options:
       return ""
-    return cls.compression_type_map[options.compression_type]
+    elif isinstance(options, TFRecordOptions):
+      return cls.get_compression_type_string(options.compression_type)
+    elif isinstance(options, TFRecordCompressionType):
+      return cls.compression_type_map[options]
+    elif options in TFRecordOptions.compression_type_map:
+      return cls.compression_type_map[options]
+    elif options in TFRecordOptions.compression_type_map.values():
+      return options
+    else:
+      raise ValueError('Not a valid compression_type: "{}"'.format(options))
+
+  def _as_record_writer_options(self):
+    """Convert to RecordWriterOptions for use with PyRecordWriter."""
+    options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
+        compat.as_bytes(
+            self.get_compression_type_string(self.compression_type)))
+
+    if self.flush_mode is not None:
+      options.zlib_options.flush_mode = self.flush_mode
+    if self.input_buffer_size is not None:
+      options.zlib_options.input_buffer_size = self.input_buffer_size
+    if self.output_buffer_size is not None:
+      options.zlib_options.output_buffer_size = self.output_buffer_size
+    if self.window_bits is not None:
+      options.zlib_options.window_bits = self.window_bits
+    if self.compression_level is not None:
+      options.zlib_options.compression_level = self.compression_level
+    if self.compression_method is not None:
+      options.zlib_options.compression_method = self.compression_method
+    if self.mem_level is not None:
+      options.zlib_options.mem_level = self.mem_level
+    if self.compression_strategy is not None:
+      options.zlib_options.compression_strategy = self.compression_strategy
+    return options
 
 
 @tf_export("python_io.tf_record_iterator")
@@ -100,16 +189,21 @@
 
     Args:
       path: The path to the TFRecords file.
-      options: (optional) A TFRecordOptions object.
+      options: (optional) String specifying compression type,
+          `TFRecordCompressionType`, or `TFRecordOptions` object.
 
     Raises:
       IOError: If `path` cannot be opened for writing.
+      ValueError: If valid compression_type can't be determined from `options`.
     """
-    compression_type = TFRecordOptions.get_compression_type_string(options)
+    if not isinstance(options, TFRecordOptions):
+      options = TFRecordOptions(compression_type=options)
 
     with errors.raise_exception_on_not_ok_status() as status:
+      # pylint: disable=protected-access
       self._writer = pywrap_tensorflow.PyRecordWriter_New(
-          compat.as_bytes(path), compat.as_bytes(compression_type), status)
+          compat.as_bytes(path), options._as_record_writer_options(), status)
+      # pylint: enable=protected-access
 
   def __enter__(self):
     """Enter a `with` block."""
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index b853b64..def8fe2 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -20,6 +20,8 @@
 
 import gzip
 import os
+import random
+import string
 import zlib
 
 import six
@@ -131,9 +133,6 @@
 
 class TFRecordWriterTest(TFCompressionTestCase):
 
-  def setUp(self):
-    super(TFRecordWriterTest, self).setUp()
-
   def _AssertFilesEqual(self, a, b, equal):
     for an, bn in zip(a, b):
       with open(an, "rb") as af, open(bn, "rb") as bf:
@@ -142,6 +141,37 @@
         else:
           self.assertNotEqual(af.read(), bf.read())
 
+  def _CompressionSizeDelta(self, records, options_a, options_b):
+    """Validate compression with options_a and options_b and return size delta.
+
+    Compress records with options_a and options_b. Uncompress both compressed
+    files and assert that the contents match the original records. Finally
+    calculate how much smaller the file compressed with options_a was than the
+    file compressed with options_b.
+
+    Args:
+      records: The records to compress
+      options_a: First set of options to compress with, the baseline for size.
+      options_b: Second set of options to compress with.
+
+    Returns:
+      The difference in file size when using options_a vs options_b. A positive
+      value means options_a was a better compression than options_b. A negative
+      value means options_b had better compression than options_a.
+
+    """
+
+    fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a)
+    test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a))
+    self.assertEqual(records, test_a, options_a)
+
+    fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b)
+    test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b))
+    self.assertEqual(records, test_b, options_b)
+
+    # Negative number => better compression.
+    return os.path.getsize(fn_a) - os.path.getsize(fn_b)
+
   def testWriteReadZLibFiles(self):
     # Write uncompressed then compress manually.
     options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
@@ -188,6 +218,76 @@
     ]
     self._AssertFilesEqual(uncompressed_files, files, True)
 
+  def testNoCompressionType(self):
+    self.assertEqual(
+        "",
+        tf_record.TFRecordOptions.get_compression_type_string(
+            tf_record.TFRecordOptions()))
+
+    self.assertEqual(
+        "",
+        tf_record.TFRecordOptions.get_compression_type_string(
+            tf_record.TFRecordOptions("")))
+
+    with self.assertRaises(ValueError):
+      tf_record.TFRecordOptions(5)
+
+    with self.assertRaises(ValueError):
+      tf_record.TFRecordOptions("BZ2")
+
+  def testZlibCompressionType(self):
+    zlib_t = tf_record.TFRecordCompressionType.ZLIB
+
+    self.assertEqual(
+        "ZLIB",
+        tf_record.TFRecordOptions.get_compression_type_string(
+            tf_record.TFRecordOptions("ZLIB")))
+
+    self.assertEqual(
+        "ZLIB",
+        tf_record.TFRecordOptions.get_compression_type_string(
+            tf_record.TFRecordOptions(zlib_t)))
+
+    self.assertEqual(
+        "ZLIB",
+        tf_record.TFRecordOptions.get_compression_type_string(
+            tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t))))
+
+  def testCompressionOptions(self):
+    # Create record with mix of random and repeated data to test compression on.
+    rnd = random.Random(123)
+    random_record = compat.as_bytes(
+        "".join(rnd.choice(string.digits) for _ in range(10000)))
+    repeated_record = compat.as_bytes(_TEXT)
+    for _ in range(10000):
+      start_i = rnd.randint(0, len(_TEXT))
+      length = rnd.randint(10, 200)
+      repeated_record += _TEXT[start_i:start_i + length]
+    records = [random_record, repeated_record, random_record]
+
+    tests = [
+        ("compression_level", 2, -1),  # Lower compression is worse.
+        ("compression_level", 6, 0),  # Default compression_level is equal.
+        ("flush_mode", zlib.Z_FULL_FLUSH, 1),  # A few less bytes.
+        ("flush_mode", zlib.Z_NO_FLUSH, 0),  # NO_FLUSH is the default.
+        ("input_buffer_size", 4096, 0),  # Increases time not size.
+        ("output_buffer_size", 4096, 0),  # Increases time not size.
+        ("window_bits", 8, -1),  # Smaller than default window increases size.
+        ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1),  # Worse.
+        ("compression_strategy", zlib.Z_FILTERED, -1),  # Worse.
+    ]
+
+    compression_type = tf_record.TFRecordCompressionType.ZLIB
+    options_a = tf_record.TFRecordOptions(compression_type)
+    for prop, value, delta_sign in tests:
+      options_b = tf_record.TFRecordOptions(
+          compression_type=compression_type, **{prop: value})
+      delta = self._CompressionSizeDelta(records, options_a, options_b)
+      self.assertTrue(
+          delta == 0 if delta_sign == 0 else delta // delta_sign > 0,
+          "Setting {} = {}, file was {} smaller didn't match sign of {}".format(
+              prop, value, delta, delta_sign))
+
 
 class TFRecordWriterZlibTest(TFCompressionTestCase):
 
@@ -318,6 +418,7 @@
       for _ in tf_record.tf_record_iterator(fn_truncated):
         pass
 
+
 class TFRecordWriterCloseAndFlushTests(test.TestCase):
 
   def setUp(self, compression_type=TFRecordCompressionType.NONE):
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 6ae869b..ade86e8 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -805,3 +805,22 @@
   indices = op.inputs[1]
   updates_grad = array_ops.gather_nd(grad, indices)
   return [grad, None, updates_grad]
+
+
+@ops.RegisterGradient("BroadcastTo")
+def _BroadcastToGrad(op, grad):
+  input_value = op.inputs[0]
+  broadcast_shape = op.inputs[1]
+  # Assign ids for each position in input_value.
+  input_value_shape = array_ops.shape(input_value)
+  input_value_size = array_ops.size(input_value)
+  ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape)
+  broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape)
+  # Group by ids and sum its gradients.
+  grad_flatten = array_ops.reshape(grad, [-1])
+  broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1])
+  updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten,
+                                                       broadcast_ids_flatten,
+                                                       input_value_size)
+  updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape)
+  return [updates_grad, None]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 21ccbc6..c8b8833 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1275,7 +1275,7 @@
 def split(value, num_or_size_splits, axis=0, num=None, name="split"):
   """Splits a tensor into sub tensors.
 
-  If `num_or_size_splits` is an integer type, `num_split`, then splits `value`
+  If `num_or_size_splits` is an integer type, then `value` is split
   along dimension `axis` into `num_split` smaller tensors.
   Requires that `num_split` evenly divides `value.shape[axis]`.
 
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index f7cbfe0..720f9f4 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -24,11 +24,17 @@
 
 # Re-exporting ops used by other modules.
 # pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
 # pylint: enable=unused-import
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 6528062..c3cf6e6 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1292,3 +1292,9 @@
     shape = tensor_shape.TensorShape(shape)
 
   return array_ops.ensure_shape(x, shape, name=name)
+
+
+@ops.RegisterGradient('EnsureShape')
+def _ensure_shape_grad(op, grad):
+  del op  # Unused.
+  return grad
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 78b395a..2946843 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -144,7 +144,11 @@
     t = ops.convert_to_tensor(t, name="t")
 
     # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
-    l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True))
+    l2sum = math_ops.reduce_sum(t * t, axes, keepdims=True)
+    pred = l2sum > 0
+    # Two-tap tf.where trick to bypass NaN gradients
+    l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
+    l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
     intermediate = t * clip_norm
     # Assert that the shape is compatible with the initial shape,
     # to prevent unintentional broadcasting.
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c4e9c98..c6a6b2a 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -180,16 +180,16 @@
 
 
 def _get_func_graphs(if_op):
-  """Returns `_FuncGraph`s for the input op branches.
+  """Returns `FuncGraph`s for the input op branches.
 
   Args:
     if_op: The _If Operation.
 
   Returns:
-    A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch.
+    A 2-tuple of the `FuncGraph`s of the then_branch and else_branch.
   """
   def _get_func_graph_for_branch(branch_name):
-    """Generates and returns a _FuncGraph for the given branch."""
+    """Generates and returns a FuncGraph for the given branch."""
     inputs = if_op.inputs[1:]  # First input is pred.
     input_shapes = [t.shape for t in inputs]
     func_name = if_op.get_attr(branch_name).name
@@ -197,7 +197,7 @@
     # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
     # in the case of nested if ops or when the gradient is being computed
     # from inside a Defun. We build the `func_graph` with `if_op.graph` as its
-    # `outer_graph`. This resembles how the `_FuncGraph` was built in the
+    # `outer_graph`. This resembles how the `FuncGraph` was built in the
     # forward pass. We need this so that we can resolve references to tensors
     # in `func_graph` from its gradient graph in `_resolve_grad_inputs`.
     with if_op.graph.as_default():
@@ -221,7 +221,7 @@
   func_graph's outputs w.r.t. its inputs.
 
   Args:
-    func_graph: function._FuncGraph. The corresponding forward-pass function.
+    func_graph: function.FuncGraph. The corresponding forward-pass function.
     grads: The list of input gradient Tensors.
 
   Returns:
@@ -259,7 +259,7 @@
 
 
 def _create_grad_func(func_graph, grads, name):
-  """Returns the _FuncGraph representation of _grad_fn."""
+  """Returns the FuncGraph representation of _grad_fn."""
   return _function.func_graph_from_py_func(
       name, lambda: _grad_fn(func_graph, grads), [], {})
 
@@ -277,8 +277,8 @@
      functions, this is always possible.
 
   Args:
-    cond_graph: function._FuncGraph. The forward-pass function.
-    grad_graph: function._FuncGraph. The gradients function.
+    cond_graph: function.FuncGraph. The forward-pass function.
+    grad_graph: function.FuncGraph. The gradients function.
 
   Returns:
     A list of inputs tensors to be passed to grad_graph.
@@ -313,7 +313,7 @@
   """Converts func_graph to a TF_Function and adds it to the current graph.
 
   Args:
-    func_graph: function._FuncGraph
+    func_graph: function.FuncGraph
 
   Returns:
     The name of the new TF_Function.
@@ -365,8 +365,8 @@
   There is no merging of params.
 
   Args:
-    true_graph: function._FuncGraph
-    false_graph: function._FuncGraph
+    true_graph: function.FuncGraph
+    false_graph: function.FuncGraph
     true_params: a list of Tensors from true_graph
     false_params: a list of Tensors from false_graph
 
@@ -391,8 +391,8 @@
   graph to avoid duplicating shared arguments.
 
   Args:
-    true_graph: function._FuncGraph
-    false_graph: function._FuncGraph
+    true_graph: function.FuncGraph
+    false_graph: function.FuncGraph
     true_inputs: a list of Tensors in the outer graph. The inputs for
       true_graph.
     false_inputs: a list of Tensors in the outer graph. The inputs for
@@ -421,7 +421,7 @@
       _create_dummy_params(false_graph, true_only_inputs) +
       [false_input_to_param[t] for t in false_only_inputs])
 
-  # Rewrite the _FuncGraphs' state to reflect the new inputs.
+  # Rewrite the FuncGraphs' state to reflect the new inputs.
   true_graph.captures = collections.OrderedDict(zip(new_inputs,
                                                     true_graph.inputs))
   false_graph.captures = collections.OrderedDict(zip(new_inputs,
@@ -434,7 +434,7 @@
   """Creates tensors in func_graph to represent template_tensors.
 
   Args:
-    func_graph: function._FuncGraph.
+    func_graph: function.FuncGraph.
     template_tensors: a list of tensors in the outer graph.
 
   Returns:
@@ -451,27 +451,16 @@
   Ensures this name is unique in the entire hierarchy.
 
   Args:
-    func_graph: The _FuncGraph.
+    func_graph: The FuncGraph.
 
   Returns:
     A string, the name to use for the gradient function.
   """
   name = "%s_grad" % func_graph.name
-
-  base_name = name
-  counter = 1
-  has_conflict = True
-  while has_conflict:
-    curr_graph = func_graph.outer_graph
-    has_conflict = curr_graph._is_function(name)
-    while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
-      curr_graph = curr_graph.outer_graph
-      has_conflict = curr_graph._is_function(name)
-    if has_conflict:
-      name = "%s_%s" % (base_name, counter)
-      counter += 1
-
-  return name
+  outer_most_graph = func_graph
+  while isinstance(outer_most_graph, _function.FuncGraph):
+    outer_most_graph = outer_most_graph.outer_graph
+  return outer_most_graph.unique_name(name)
 
 
 def _check_same_outputs(true_graph, false_graph):
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e3c1aa3..0e20fad 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,7 +61,7 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
-_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
 
 
 # We override the 'tuple' for a control flow op, so we keep python's
@@ -2026,7 +2026,7 @@
   ```
 
   """
-  if _ENABLE_COND_V2:
+  if ENABLE_COND_V2 and not context.executing_eagerly():
     return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)
 
   # We needed to make true_fn/false_fn keyword arguments for
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 908e793..32d455b 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -242,11 +242,11 @@
 
   If `merge_repeated` is `True`, merge repeated classes in the output beams.
   This means that if consecutive entries in a beam are the same,
-  only the first of these is emitted.  That is, when the top path
-  is `A B B B B`, the return value is:
+  only the first of these is emitted.  That is, when the sequence is
+  `A B B * B * B` (where '*' is the blank label), the return value is:
 
     * `A B` if `merge_repeated = True`.
-    * `A B B B B` if `merge_repeated = False`.
+    * `A B B B` if `merge_repeated = False`.
 
   Args:
     inputs: 3-D `float` `Tensor`, size
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 7af2ca5..69c0fcb 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1229,7 +1229,8 @@
                dtype,
                shape=None,
                shared_name=None,
-               name="conditional_accumulator"):
+               name="conditional_accumulator",
+               reduction_type="MEAN"):
     """Creates a new ConditionalAccumulator.
 
     Args:
@@ -1238,9 +1239,14 @@
       shared_name: Optional. If non-empty, this accumulator will be shared under
         the given name across multiple sessions.
       name: Optional name for the accumulator.
+      reduction_type: Reduction type to use when taking the gradient.
     """
     accumulator_ref = gen_data_flow_ops.conditional_accumulator(
-        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+        dtype=dtype,
+        shape=shape,
+        shared_name=shared_name,
+        name=name,
+        reduction_type=reduction_type)
     super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
 
   def apply_grad(self, grad, local_step=0, name=None):
@@ -1312,15 +1318,21 @@
     shared_name: Optional. If non-empty, this accumulator will be shared under
       the given name across multiple sessions.
     name: Optional name for the accumulator.
+    reduction_type: Reduction type to use when taking the gradient.
   """
 
   def __init__(self,
                dtype,
                shape=None,
                shared_name=None,
-               name="sparse_conditional_accumulator"):
+               name="sparse_conditional_accumulator",
+               reduction_type="MEAN"):
     accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
-        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+        dtype=dtype,
+        shape=shape,
+        shared_name=shared_name,
+        name=name,
+        reduction_type=reduction_type)
     super(SparseConditionalAccumulator, self).__init__(dtype, shape,
                                                        accumulator_ref)
 
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index b65e64d..2e7aa30 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -1011,12 +1011,6 @@
   def _reduce_jacobian_det_over_event(
       self, y, ildj, min_event_ndims, event_ndims):
     """Reduce jacobian over event_ndims - min_event_ndims."""
-
-    if not self.is_constant_jacobian:
-      return math_ops.reduce_sum(
-          ildj,
-          self._get_event_reduce_dims(min_event_ndims, event_ndims))
-
     # In this case, we need to tile the Jacobian over the event and reduce.
     y_rank = array_ops.rank(y)
     y_shape = array_ops.shape(y)[
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index dd25fce..fbbacf2 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -69,7 +69,7 @@
   The Categorical distribution is closely related to the `OneHotCategorical` and
   `Multinomial` distributions.  The Categorical distribution can be intuited as
   generating samples according to `argmax{ OneHotCategorical(probs) }` itself
-  being identical to `argmax{ Multinomial(probs, total_count=1) }.
+  being identical to `argmax{ Multinomial(probs, total_count=1) }`.
 
   #### Mathematical Details
 
@@ -83,7 +83,7 @@
 
   The number of classes, `K`, must not exceed:
   - the largest integer representable by `self.dtype`, i.e.,
-    `2**(mantissa_bits+1)` (IEE754),
+    `2**(mantissa_bits+1)` (IEEE 754),
   - the maximum `Tensor` index, i.e., `2**31-1`.
 
   In other words,
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index ddf9442..578e7b7 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -446,6 +446,24 @@
     self._graph_parents = graph_parents
     self._name = name
 
+  @property
+  def _parameters(self):
+    return self._parameter_dict
+
+  @_parameters.setter
+  def _parameters(self, value):
+    """Intercept assignments to self._parameters to avoid reference cycles.
+
+    Parameters are often created using locals(), so we need to clean out any
+    references to `self` before assigning it to an attribute.
+
+    Args:
+      value: A dictionary of parameters to assign to the `_parameters` property.
+    """
+    if "self" in value:
+      del value["self"]
+    self._parameter_dict = value
+
   @classmethod
   def param_shapes(cls, sample_shape, name="DistributionParamShapes"):
     """Shapes of parameters given the desired shape of a call to `sample()`.
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 9fa8e27..1dc666e 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -19,10 +19,10 @@
 from __future__ import print_function
 
 # pylint: disable=unused-import
+from tensorflow.python.eager import function
 from tensorflow.python.eager.backprop import GradientTape
 from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.gradients_impl import AggregationMethod
 from tensorflow.python.ops.gradients_impl import gradients
 from tensorflow.python.ops.gradients_impl import hessians
 # pylint: enable=unused-import
-
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index a68f6802..196161c 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -31,7 +31,7 @@
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
+from tensorflow.python.framework import function as framework_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
@@ -58,6 +58,10 @@
 from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import tf_export
 
+# This is to avoid a circular dependency (eager.function depends on
+# gradients_impl). This is set in eager/function.py.
+_function = None
+
 # This is to avoid a circular dependency with cond_v2_impl.
 cond_v2_impl._gradients_impl = sys.modules[__name__]  # pylint: disable=protected-access
 
@@ -121,7 +125,7 @@
   Args:
     from_ops: list of Operations.
     reached_ops: set of Operations.
-    func_graphs: list of function._FuncGraphs. This method will traverse through
+    func_graphs: list of _function.FuncGraphs. This method will traverse through
       these functions if they capture from_ops or any reachable ops.
   """
   queue = collections.deque()
@@ -146,7 +150,7 @@
     to_ops: list of Operations.
     from_ops: list of Operations.
     colocate_gradients_with_ops: Python bool.  See docstring of gradients().
-    func_graphs: list of function._FuncGraphs. This method will traverse through
+    func_graphs: list of _function.FuncGraphs. This method will traverse through
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
@@ -256,6 +260,12 @@
               "Gradient type %s generated for complex-valued "
               "tensor %s with type %s must be real" % (dtypes.as_dtype(
                   grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+      elif y.dtype == dtypes.variant:
+        if grad_y.dtype != dtypes.variant:
+          raise TypeError(
+              "Gradient type %s generated for variant "
+              "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+                  grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
       else:
         raise TypeError(
             "Tensor %s with type %s must be numeric "
@@ -294,7 +304,7 @@
   if _IsTrainable(tensor):
     return True
   dtype = dtypes.as_dtype(tensor.dtype)
-  return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+  return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
 
 
 def _VerifyGeneratedGradients(grads, op):
@@ -441,6 +451,19 @@
       % target_op.name)
 
 
+def _IsFunction(graph):
+  return (isinstance(graph, _function.FuncGraph) or
+          isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access
+
+
+def _Captures(func_graph):
+  if isinstance(func_graph, _function.FuncGraph):
+    return func_graph.captures
+  else:
+    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
+    return func_graph._captured  # pylint: disable=protected-access
+
+
 def _MaybeCaptured(t):
   """If t is a captured value placeholder, returns the original captured value.
 
@@ -448,11 +471,11 @@
     t: Tensor
 
   Returns:
-    A tensor, potentially from a different Graph/function._FuncGraph.
+    A tensor, potentially from a different Graph/_function.FuncGraph.
   """
   # pylint: disable=protected-access
-  if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder":
-    for input_t, placeholder_t in t.op.graph._captured.items():
+  if _IsFunction(t.op.graph) and t.op.type == "Placeholder":
+    for input_t, placeholder_t in _Captures(t.op.graph).items():
       if t == placeholder_t:
         return _MaybeCaptured(input_t)
   # pylint: enable=protected-access
@@ -470,10 +493,10 @@
 
   Returns:
     A list of tensors. The tensors may be from multiple
-    Graph/function._FuncGraphs if op is in a function._FuncGraph and has
+    Graph/_function.FuncGraphs if op is in a _function.FuncGraph and has
     captured inputs.
   """
-  if isinstance(op.graph, function._FuncGraph):  # pylint: disable=protected-access
+  if _IsFunction(op.graph):  # pylint: disable=protected-access
     # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
     # to a captured value. The algorithm needs to "see" `t` in this case, even
     # if it's a function input for a captured value, whereas usually we'd like
@@ -489,7 +512,7 @@
 
   Args:
     t: Tensor
-    func_graphs: a list of function._FuncGraphs that may have captured t.
+    func_graphs: a list of _function.FuncGraphs that may have captured t.
 
   Returns:
     A list of tensors. The tensors will be from the current graph and/or
@@ -497,7 +520,7 @@
   """
   consumers = t.consumers()
   for func in func_graphs:
-    for input_t, placeholder in func._captured.items():  # pylint: disable=protected-access
+    for input_t, placeholder in _Captures(func).items():
       if input_t == t:
         consumers.extend(_Consumers(placeholder, func_graphs))
   return consumers
@@ -616,9 +639,13 @@
   # ancestor graphs. This is necessary for correctly handling captured values.
   func_graphs = []
   curr_graph = src_graph
-  while isinstance(curr_graph, function._FuncGraph):  # pylint: disable=protected-access
+  while _IsFunction(curr_graph):
     func_graphs.append(curr_graph)
-    curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access
+    if isinstance(curr_graph, _function.FuncGraph):
+      curr_graph = curr_graph.outer_graph
+    else:
+      assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
+      curr_graph = curr_graph._outer_graph  # pylint: disable=protected-access
 
   ys = _AsList(ys)
   xs = _AsList(xs)
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index fa9910b..6243be6 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -26,9 +26,10 @@
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
+from tensorflow.python.framework import function as framework_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
@@ -44,6 +45,7 @@
 from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
@@ -369,8 +371,8 @@
 
   @classmethod
   def _GetFunc(cls, **kwargs):
-    return function.Defun(dtypes.float32, dtypes.float32, **
-                          kwargs)(cls.XSquarePlusB)
+    return framework_function.Defun(dtypes.float32, dtypes.float32, **
+                                    kwargs)(cls.XSquarePlusB)
 
   def _GetFuncGradients(self, f, x_value, b_value):
     x = constant_op.constant(x_value, name="x")
@@ -408,8 +410,9 @@
   def testFunctionGradientsWithGradFunc(self):
     g = ops.Graph()
     with g.as_default():
-      grad_func = function.Defun(dtypes.float32, dtypes.float32,
-                                 dtypes.float32)(self.XSquarePlusBGradient)
+      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+                                           dtypes.float32)(
+                                               self.XSquarePlusBGradient)
       f = self._GetFunc(grad_func=grad_func)
       # Get gradients (should add SymbolicGradient node for function, which
       # uses the grad_func above, which multiplies all gradients by 2).
@@ -430,8 +433,9 @@
   def testFunctionGradientWithGradFuncAndRegistration(self):
     g = ops.Graph()
     with g.as_default():
-      grad_func = function.Defun(dtypes.float32, dtypes.float32,
-                                 dtypes.float32)(self.XSquarePlusBGradient)
+      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
+                                           dtypes.float32)(
+                                               self.XSquarePlusBGradient)
       with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
         f = self._GetFunc(
             grad_func=grad_func, python_grad_func=self._PythonGradient)
@@ -441,7 +445,7 @@
     with ops.Graph().as_default():
       x = constant_op.constant(1.0, name="x")
 
-      @function.Defun()
+      @function.defun()
       def Foo():
         y = math_ops.multiply(x, 2.0, name="y")
         g = gradients_impl.gradients(y, x)
@@ -456,7 +460,7 @@
       x = constant_op.constant(1.0, name="x")
       y = math_ops.multiply(x, 2.0, name="y")
 
-      @function.Defun()
+      @framework_function.Defun()
       def Foo():
         g = gradients_impl.gradients(y, x)
         return g[0]
@@ -469,7 +473,7 @@
     with ops.Graph().as_default():
       var = resource_variable_ops.ResourceVariable(1.0, name="var")
 
-      @function.Defun()
+      @function.defun()
       def Foo():
         y = math_ops.multiply(var, 2.0, name="y")
         g = gradients_impl.gradients(y, var)
@@ -486,11 +490,11 @@
       x2 = constant_op.constant(2.0, name="x2")
       x3 = math_ops.multiply(x1, x2, name="x3")
 
-      @function.Defun()
+      @function.defun()
       def Outer():
         outer1 = array_ops.identity(x1, name="outer1")
 
-        @function.Defun()
+        @function.defun()
         def Inner():
           inner1 = array_ops.identity(outer1, name="inner1")
           inner2 = array_ops.identity(x2, name="inner2")
@@ -511,11 +515,11 @@
     with ops.Graph().as_default():
       x = constant_op.constant(1.0, name="x")
 
-      @function.Defun()
+      @function.defun()
       def Outer():
         y = math_ops.multiply(x, 2.0, name="y")
 
-        @function.Defun()
+        @function.defun()
         def Inner():
           z = math_ops.multiply(y, 3.0, name="z")
           g = gradients_impl.gradients(z, y)
@@ -1001,5 +1005,25 @@
     self._assert_indexed_slices_equal(total, result)
 
 
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+  def testDefaultGradYs(self):
+    with ops.Graph().as_default():
+      tl = list_ops.empty_tensor_list(
+          element_dtype=dtypes.float32,
+          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+      a = constant(1.0)
+      tl = list_ops.tensor_list_push_back(tl, a)
+
+      grad_tl = list_ops.empty_tensor_list(
+          element_dtype=dtypes.float32,
+          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+      with self.cached_session() as sess:
+        self.assertEquals(sess.run(grad), 5.)
+
+
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 1235694..de260f3 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -330,6 +330,8 @@
           lambda: image,
           name=scope
       )
+      if isinstance(result, tuple):
+        result = result[0]  # TODO(b/111124878) remove this logic (CondV2).
       return fix_image_flip_shape(image, result)
     elif shape.ndims == 4:
       uniform_random = random_ops.random_uniform(
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index f7502c4..795e6bb 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3657,6 +3657,47 @@
       scores = constant_op.constant([0.9])
       image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
 
+  def testDataTypes(self):
+    # Test case for GitHub issue 20199.
+    boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+                [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+    scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+    max_output_size_np = 3
+    iou_threshold_np = 0.5
+    # Note: There are multiple versions of non_max_suppression v2, v3, v4.
+    # gen_image_ops.non_max_suppression_v2:
+    for dtype in [np.float16, np.float32]:
+      with self.test_session():
+        boxes = constant_op.constant(boxes_np, dtype=dtype)
+        scores = constant_op.constant(scores_np, dtype=dtype)
+        max_output_size = constant_op.constant(max_output_size_np)
+        iou_threshold = constant_op.constant(iou_threshold_np)
+        selected_indices = gen_image_ops.non_max_suppression_v2(
+            boxes, scores, max_output_size, iou_threshold).eval()
+        self.assertAllClose(selected_indices, [3, 0, 5])
+    # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
+    for dtype in [np.float16, np.float32]:
+      with self.test_session():
+        boxes = constant_op.constant(boxes_np, dtype=dtype)
+        scores = constant_op.constant(scores_np, dtype=dtype)
+        max_output_size = constant_op.constant(max_output_size_np)
+        iou_threshold = constant_op.constant(iou_threshold_np)
+        selected_indices = image_ops.non_max_suppression(
+            boxes, scores, max_output_size, iou_threshold).eval()
+        self.assertAllClose(selected_indices, [3, 0, 5])
+    # gen_image_ops.non_max_suppression_v4.
+    score_threshold = float('-inf')
+    for dtype in [np.float16, np.float32]:
+      with self.test_session():
+        boxes = constant_op.constant(boxes_np, dtype=dtype)
+        scores = constant_op.constant(scores_np, dtype=dtype)
+        max_output_size = constant_op.constant(max_output_size_np)
+        iou_threshold = constant_op.constant(iou_threshold_np)
+        selected_indices, _ = gen_image_ops.non_max_suppression_v4(
+            boxes, scores, max_output_size, iou_threshold, score_threshold)
+        selected_indices = selected_indices.eval()
+        self.assertAllClose(selected_indices, [3, 0, 5])
+
 
 class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index fbc1350..f84785d 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -33,8 +33,9 @@
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_io_ops import *
-from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=wildcard-import
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
 
 
 # pylint: disable=protected-access
@@ -95,7 +96,7 @@
       preferred_shard, name=name)
 
 
-@tf_export("ReaderBase")
+@tf_export(v1=["ReaderBase"])
 class ReaderBase(object):
   """Base class for different Reader types, that produce a record every step.
 
@@ -309,7 +310,7 @@
 ops.NotDifferentiable("ReaderReset")
 
 
-@tf_export("WholeFileReader")
+@tf_export(v1=["WholeFileReader"])
 class WholeFileReader(ReaderBase):
   """A Reader that outputs the entire contents of a file as a value.
 
@@ -324,6 +325,9 @@
   @end_compatibility
   """
 
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.data.Dataset.map(tf.read_file)`.")
   def __init__(self, name=None):
     """Create a WholeFileReader.
 
@@ -337,7 +341,7 @@
 ops.NotDifferentiable("WholeFileReader")
 
 
-@tf_export("TextLineReader")
+@tf_export(v1=["TextLineReader"])
 class TextLineReader(ReaderBase):
   """A Reader that outputs the lines of a file delimited by newlines.
 
@@ -351,6 +355,9 @@
   """
   # TODO(josh11b): Support serializing and restoring state.
 
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.data.TextLineDataset`.")
   def __init__(self, skip_header_lines=None, name=None):
     """Create a TextLineReader.
 
@@ -367,7 +374,7 @@
 ops.NotDifferentiable("TextLineReader")
 
 
-@tf_export("FixedLengthRecordReader")
+@tf_export(v1=["FixedLengthRecordReader"])
 class FixedLengthRecordReader(ReaderBase):
   """A Reader that outputs fixed-length records from a file.
 
@@ -380,6 +387,9 @@
   """
   # TODO(josh11b): Support serializing and restoring state.
 
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.data.FixedLengthRecordDataset`.")
   def __init__(self,
                record_bytes,
                header_bytes=None,
@@ -410,7 +420,7 @@
 ops.NotDifferentiable("FixedLengthRecordReader")
 
 
-@tf_export("TFRecordReader")
+@tf_export(v1=["TFRecordReader"])
 class TFRecordReader(ReaderBase):
   """A Reader that outputs the records from a TFRecords file.
 
@@ -423,6 +433,9 @@
   """
   # TODO(josh11b): Support serializing and restoring state.
 
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.data.TFRecordDataset`.")
   def __init__(self, name=None, options=None):
     """Create a TFRecordReader.
 
@@ -441,7 +454,7 @@
 ops.NotDifferentiable("TFRecordReader")
 
 
-@tf_export("LMDBReader")
+@tf_export(v1=["LMDBReader"])
 class LMDBReader(ReaderBase):
   """A Reader that outputs the records from a LMDB file.
 
@@ -452,6 +465,10 @@
   use `tf.data` to get data into your model.
   @end_compatibility
   """
+
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.contrib.data.LMDBDataset`.")
   def __init__(self, name=None, options=None):
     """Create a LMDBReader.
 
@@ -459,6 +476,7 @@
       name: A name for the operation (optional).
       options: A LMDBRecordOptions object (optional).
     """
+    del options
     rr = gen_io_ops.lmdb_reader(name=name)
     super(LMDBReader, self).__init__(rr)
 
@@ -466,7 +484,7 @@
 ops.NotDifferentiable("LMDBReader")
 
 
-@tf_export("IdentityReader")
+@tf_export(v1=["IdentityReader"])
 class IdentityReader(ReaderBase):
   """A Reader that outputs the queued work as both the key and value.
 
@@ -481,6 +499,9 @@
   @end_compatibility
   """
 
+  @deprecation.deprecated(
+      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+      "`tf.data.Dataset.map(...)`.")
   def __init__(self, name=None):
     """Create a IdentityReader.
 
diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py
new file mode 100644
index 0000000..86130a2
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_addition.py
@@ -0,0 +1,432 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Add one or more `LinearOperators` efficiently."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_full_matrix
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_lower_triangular
+
+__all__ = []
+
+
+def add_operators(operators,
+                  operator_name=None,
+                  addition_tiers=None,
+                  name=None):
+  """Efficiently add one or more linear operators.
+
+  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
+  operators `[B1, B2,...]` such that
+
+  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
+
+  The operators `Bk` result by adding some of the `Ak`, as allowed by
+  `addition_tiers`.
+
+  Example of efficient adding of diagonal operators.
+
+  ```python
+  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
+  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
+
+  # Use two tiers, the first contains an Adder that returns Diag.  Since both
+  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
+  # used.
+  addition_tiers = [
+      [_AddAndReturnDiag()],
+      [_AddAndReturnMatrix()]]
+  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
+
+  len(B_list)
+  ==> 1
+
+  B_list[0].__class__.__name__
+  ==> 'LinearOperatorDiag'
+
+  B_list[0].to_dense()
+  ==> [[3., 0.],
+       [0., 3.]]
+
+  B_list[0].name
+  ==> 'Add/A1__A2/'
+  ```
+
+  Args:
+    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
+      and range dimensions, and broadcastable batch shapes.
+    operator_name:  String name for returned `LinearOperator`.  Defaults to
+      concatenation of "Add/A__B/" that indicates the order of addition steps.
+    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
+      is a list of `Adder` objects.  This function attempts to do all additions
+      in tier `i` before trying tier `i + 1`.
+    name:  A name for this `Op`.  Defaults to `add_operators`.
+
+  Returns:
+    Subclass of `LinearOperator`.  Class and order of addition may change as new
+      (and better) addition strategies emerge.
+
+  Raises:
+    ValueError:  If `operators` argument is empty.
+    ValueError:  If shapes are incompatible.
+  """
+  # Default setting
+  if addition_tiers is None:
+    addition_tiers = _DEFAULT_ADDITION_TIERS
+
+  # Argument checking.
+  check_ops.assert_proper_iterable(operators)
+  operators = list(reversed(operators))
+  if len(operators) < 1:
+    raise ValueError(
+        "Argument 'operators' must contain at least one operator.  "
+        "Found: %s" % operators)
+  if not all(
+      isinstance(op, linear_operator.LinearOperator) for op in operators):
+    raise TypeError(
+        "Argument 'operators' must contain only LinearOperator instances.  "
+        "Found: %s" % operators)
+  _static_check_for_same_dimensions(operators)
+  _static_check_for_broadcastable_batch_shape(operators)
+
+  graph_parents = []
+  for operator in operators:
+    graph_parents.extend(operator.graph_parents)
+
+  with ops.name_scope(name or "add_operators", values=graph_parents):
+
+    # Additions done in one of the tiers.  Try tier 0, 1,...
+    ops_to_try_at_next_tier = list(operators)
+    for tier in addition_tiers:
+      ops_to_try_at_this_tier = ops_to_try_at_next_tier
+      ops_to_try_at_next_tier = []
+      while ops_to_try_at_this_tier:
+        op1 = ops_to_try_at_this_tier.pop()
+        op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
+        if op2 is not None:
+          # Will try to add the result of this again at this same tier.
+          new_operator = adder.add(op1, op2, operator_name)
+          ops_to_try_at_this_tier.append(new_operator)
+        else:
+          ops_to_try_at_next_tier.append(op1)
+
+    return ops_to_try_at_next_tier
+
+
+def _pop_a_match_at_tier(op1, operator_list, tier):
+  # Search from the back of list to the front in order to create nice default
+  # order of operations.
+  for i in range(1, len(operator_list) + 1):
+    op2 = operator_list[-i]
+    for adder in tier:
+      if adder.can_add(op1, op2):
+        return operator_list.pop(-i), adder
+  return None, None
+
+
+def _infer_hints_allowing_override(op1, op2, hints):
+  """Infer hints from op1 and op2.  hints argument is an override.
+
+  Args:
+    op1:  LinearOperator
+    op2:  LinearOperator
+    hints:  _Hints object holding "is_X" boolean hints to use for returned
+      operator.
+      If some hint is None, try to set using op1 and op2.  If the
+      hint is provided, ignore op1 and op2 hints.  This allows an override
+      of previous hints, but does not allow forbidden hints (e.g. you still
+      cannot say a real diagonal operator is not self-adjoint).
+
+  Returns:
+    _Hints object.
+  """
+  hints = hints or _Hints()
+  # If A, B are self-adjoint, then so is A + B.
+  if hints.is_self_adjoint is None:
+    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
+  else:
+    is_self_adjoint = hints.is_self_adjoint
+
+  # If A, B are positive definite, then so is A + B.
+  if hints.is_positive_definite is None:
+    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
+  else:
+    is_positive_definite = hints.is_positive_definite
+
+  # A positive definite operator is always non-singular.
+  if is_positive_definite and hints.is_non_singular is None:
+    is_non_singular = True
+  else:
+    is_non_singular = hints.is_non_singular
+
+  return _Hints(
+      is_non_singular=is_non_singular,
+      is_self_adjoint=is_self_adjoint,
+      is_positive_definite=is_positive_definite)
+
+
+def _static_check_for_same_dimensions(operators):
+  """ValueError if operators determined to have different dimensions."""
+  if len(operators) < 2:
+    return
+
+  domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
+                       if op.domain_dimension.value is not None]
+  if len(set(value for name, value in domain_dimensions)) > 1:
+    raise ValueError("Operators must have the same domain dimension. Found: %s"
+                     % domain_dimensions)
+
+  range_dimensions = [(op.name, op.range_dimension.value) for op in operators
+                      if op.range_dimension.value is not None]
+  if len(set(value for name, value in range_dimensions)) > 1:
+    raise ValueError("Operators must have the same range dimension. Found: %s" %
+                     range_dimensions)
+
+
+def _static_check_for_broadcastable_batch_shape(operators):
+  """ValueError if operators determined to have non-broadcastable shapes."""
+  if len(operators) < 2:
+    return
+
+  # This will fail if they cannot be broadcast together.
+  batch_shape = operators[0].batch_shape
+  for op in operators[1:]:
+    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
+
+
+class _Hints(object):
+  """Holds 'is_X' flags that every LinearOperator is initialized with."""
+
+  def __init__(self,
+               is_non_singular=None,
+               is_positive_definite=None,
+               is_self_adjoint=None):
+    self.is_non_singular = is_non_singular
+    self.is_positive_definite = is_positive_definite
+    self.is_self_adjoint = is_self_adjoint
+
+
+################################################################################
+# Classes to add two linear operators.
+################################################################################
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _Adder(object):
+  """Abstract base class to add two operators.
+
+  Each `Adder` acts independently, adding everything it can, paying no attention
+  as to whether another `Adder` could have done the addition more efficiently.
+  """
+
+  @property
+  def name(self):
+    return self.__class__.__name__
+
+  @abc.abstractmethod
+  def can_add(self, op1, op2):
+    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
+    pass
+
+  @abc.abstractmethod
+  def _add(self, op1, op2, operator_name, hints):
+    # Derived classes can assume op1 and op2 have been validated, e.g. they have
+    # the same dtype, and their domain/range dimensions match.
+    pass
+
+  def add(self, op1, op2, operator_name, hints=None):
+    """Return new `LinearOperator` acting like `op1 + op2`.
+
+    Args:
+      op1:  `LinearOperator`
+      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
+        `op1` is allowed.
+      operator_name:  `String` name to give to returned `LinearOperator`
+      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
+        these hints.
+
+    Returns:
+      `LinearOperator`
+    """
+    updated_hints = _infer_hints_allowing_override(op1, op2, hints)
+
+    if operator_name is None:
+      operator_name = "Add/" + op1.name + "__" + op2.name + "/"
+
+    values = op1.graph_parents + op2.graph_parents
+    scope_name = self.name
+    if scope_name.startswith("_"):
+      scope_name = scope_name[1:]
+    with ops.name_scope(scope_name, values=values):
+      return self._add(op1, op2, operator_name, updated_hints)
+
+
+class _AddAndReturnScaledIdentity(_Adder):
+  """Handles additions resulting in an Identity family member.
+
+  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
+  is closed under addition.  This `Adder` respects that, and returns an Identity family member.
+  """
+
+  def can_add(self, op1, op2):
+    types = {_type(op1), _type(op2)}
+    return not types.difference(_IDENTITY_FAMILY)
+
+  def _add(self, op1, op2, operator_name, hints):
+    # Will build a LinearOperatorScaledIdentity.
+
+    if _type(op1) == _SCALED_IDENTITY:
+      multiplier_1 = op1.multiplier
+    else:
+      multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
+
+    if _type(op2) == _SCALED_IDENTITY:
+      multiplier_2 = op2.multiplier
+    else:
+      multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
+
+    return linear_operator_identity.LinearOperatorScaledIdentity(
+        num_rows=op1.range_dimension_tensor(),
+        multiplier=multiplier_1 + multiplier_2,
+        is_non_singular=hints.is_non_singular,
+        is_self_adjoint=hints.is_self_adjoint,
+        is_positive_definite=hints.is_positive_definite,
+        name=operator_name)
+
+
+class _AddAndReturnDiag(_Adder):
+  """Handles additions resulting in a Diag operator."""
+
+  def can_add(self, op1, op2):
+    types = {_type(op1), _type(op2)}
+    return not types.difference(_DIAG_LIKE)
+
+  def _add(self, op1, op2, operator_name, hints):
+    return linear_operator_diag.LinearOperatorDiag(
+        diag=op1.diag_part() + op2.diag_part(),
+        is_non_singular=hints.is_non_singular,
+        is_self_adjoint=hints.is_self_adjoint,
+        is_positive_definite=hints.is_positive_definite,
+        name=operator_name)
+
+
+class _AddAndReturnTriL(_Adder):
+  """Handles additions resulting in a TriL operator."""
+
+  def can_add(self, op1, op2):
+    types = {_type(op1), _type(op2)}
+    return not types.difference(_DIAG_LIKE.union({_TRIL}))
+
+  def _add(self, op1, op2, operator_name, hints):
+    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+      op_add_to_tensor, op_other = op1, op2
+    else:
+      op_add_to_tensor, op_other = op2, op1
+
+    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+        tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+        is_non_singular=hints.is_non_singular,
+        is_self_adjoint=hints.is_self_adjoint,
+        is_positive_definite=hints.is_positive_definite,
+        name=operator_name)
+
+
+class _AddAndReturnMatrix(_Adder):
+  """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
+
+  def can_add(self, op1, op2):  # pylint: disable=unused-argument
+    return isinstance(op1, linear_operator.LinearOperator) and isinstance(
+        op2, linear_operator.LinearOperator)
+
+  def _add(self, op1, op2, operator_name, hints):
+    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+      op_add_to_tensor, op_other = op1, op2
+    else:
+      op_add_to_tensor, op_other = op2, op1
+    return linear_operator_full_matrix.LinearOperatorFullMatrix(
+        matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+        is_non_singular=hints.is_non_singular,
+        is_self_adjoint=hints.is_self_adjoint,
+        is_positive_definite=hints.is_positive_definite,
+        name=operator_name)
+
+
+################################################################################
+# Constants designating types of LinearOperators
+################################################################################
+
+# Type name constants for LinearOperator classes.
+_IDENTITY = "identity"
+_SCALED_IDENTITY = "scaled_identity"
+_DIAG = "diag"
+_TRIL = "tril"
+_MATRIX = "matrix"
+
+# Groups of operators.
+_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
+_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
+# operators with an efficient .add_to_tensor() method.
+_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
+
+
+def _type(operator):
+  """Returns the type name constant (e.g. _TRIL) for operator."""
+  if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
+    return _DIAG
+  if isinstance(operator,
+                linear_operator_lower_triangular.LinearOperatorLowerTriangular):
+    return _TRIL
+  if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
+    return _MATRIX
+  if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
+    return _IDENTITY
+  if isinstance(operator,
+                linear_operator_identity.LinearOperatorScaledIdentity):
+    return _SCALED_IDENTITY
+  raise TypeError("Operator type unknown: %s" % operator)
+
+
+################################################################################
+# Addition tiers:
+# We attempt to use Adders in tier K before K+1.
+#
+# Organize tiers to
+#   (i) reduce O(..) complexity of forming final operator, and
+#   (ii) produce the "most efficient" final operator.
+# Dev notes:
+#  * Results of addition at tier K will be added at tier K or higher.
+#  * Tiers may change, and we warn the user that it may change.
+################################################################################
+
+# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
+# dense matrix.  So it is sometimes very inefficient.
+_DEFAULT_ADDITION_TIERS = [
+    [_AddAndReturnScaledIdentity()],
+    [_AddAndReturnDiag()],
+    [_AddAndReturnTriL()],
+    [_AddAndReturnMatrix()],
+]
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index c367ed2..021ef47 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -160,20 +160,20 @@
     `block_depth = 1` means `A` is symmetric circulant.  For example,
 
     ```
-    A = |x y z y|
-        |y x y z|
-        |z y x y|
-        |y z y x|
+    A = |w z y x|
+        |x w z y|
+        |y x w z|
+        |z y x w|
     ```
 
     `block_depth = 2` means `A` is block symmetric circulant with symemtric
-    circulant blocks.  For example, with `X`, `Y`, `Z` symmetric circulant,
+    circulant blocks.  For example, with `W`, `X`, `Y`, `Z` symmetric circulant,
 
     ```
-    A = |X Y Z Y|
-        |Y X Y Z|
-        |Z Y X Y|
-        |Y Z Y X|
+    A = |W Z Y X|
+        |X W Z Y|
+        |Y X W Z|
+        |Z Y X W|
     ```
 
     `block_depth = 3` means `A` is block symmetric circulant with block
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9b0ab00..7c59232 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1088,9 +1088,6 @@
   `x // y` floor division in Python 3 and in Python 2.7 with
   `from __future__ import division`.
 
-  Note that for efficiency, `floordiv` uses C semantics for negative numbers
-  (unlike Python and Numpy).
-
   `x` and `y` must have the same type, and the result will have the same type
   as well.
 
@@ -1100,7 +1097,7 @@
     name: A name for the operation (optional).
 
   Returns:
-    `x / y` rounded down (except possibly towards zero for negative integers).
+    `x / y` rounded down.
 
   Raises:
     TypeError: If the inputs are complex.
@@ -2571,7 +2568,7 @@
 
 @tf_export("unsorted_segment_mean")
 def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
-  r""" Computes the mean along segments of a tensor.
+  r"""Computes the mean along segments of a tensor.
 
   Read [the section on
   segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
@@ -2582,17 +2579,26 @@
   Instead of computing the sum over segments, it computes the mean of all
   entries belonging to a segment such that:
 
-  \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such
-  that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
-  of id \\i\\.
+  \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples
+  `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the number of
+  occurrences of id \\i\\.
 
   If there is no entry for a given segment ID `i`, it outputs 0.
 
-  segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-  first dimension.
+  If the given segment ID `i` is negative, the value is dropped and will not
+  be added to the sum of the segment.
 
-  output: Has same shape as data, except for dimension 0 which
-  has size `num_segments`.
+  Args:
+    data: A `Tensor` with floating point or complex dtype.
+    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+    num_segments: An integer scalar `Tensor`.  The number of distinct
+      segment IDs.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
+    dimensions, which are replaced with a single dimension which has size
+    `num_segments`.
   """
   with ops.name_scope(name, "UnsortedSegmentMean"):
     data = ops.convert_to_tensor(data)
@@ -2615,20 +2621,29 @@
   Additionally to computing the sum over segments, it divides the results by
   sqrt(N).
 
-  \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such
-  that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
-  of id \\i\\.
+  \\(output_i = 1/sqrt(N_i) \sum_{j...} data[j...]\\) where the sum is over
+  tuples `j...` such that `segment_ids[j...] == i` with \\N_i\\ being the
+  number of occurrences of id \\i\\.
 
   If there is no entry for a given segment ID `i`, it outputs 0.
 
   Note that this op only supports floating point and complex dtypes,
   due to tf.sqrt only supporting these types.
 
-  segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-  first dimension.
+  If the given segment ID `i` is negative, the value is dropped and will not
+  be added to the sum of the segment.
 
-  output: Has same shape as data, except for dimension 0 which
-  has size `num_segments`.
+  Args:
+    data: A `Tensor` with floating point or complex dtype.
+    segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+    num_segments: An integer scalar `Tensor`.  The number of distinct
+      segment IDs.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor`.  Has same shape as data, except for the first `segment_ids.rank`
+    dimensions, which are replaced with a single dimension which has size
+    `num_segments`.
   """
   with ops.name_scope(name, "UnsortedSegmentSqrtN"):
     data = ops.convert_to_tensor(data)
@@ -2888,22 +2903,24 @@
         free_dims_static = None
       shape_a = array_ops.shape(a)
       rank_a = array_ops.rank(a)
-      axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
-      axes = cast(axes >= 0, dtypes.int32) * axes + cast(
-          axes < 0, dtypes.int32) * (
-              axes + rank_a)
-      free, _ = array_ops.setdiff1d(range(rank_a), axes)
-      free_dims = array_ops.gather(shape_a, free)
-      axes_dims = array_ops.gather(shape_a, axes)
-      prod_free_dims = reduce_prod(free_dims)
-      prod_axes_dims = reduce_prod(axes_dims)
-      perm = array_ops.concat([axes_dims, free_dims], 0)
-      if flipped:
-        perm = array_ops.concat([axes, free], 0)
-        new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
-      else:
-        perm = array_ops.concat([free, axes], 0)
-        new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
+      # TODO(b/115583659): Automate this.
+      with ops.device("/cpu:0"):
+        axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+        axes = cast(axes >= 0, dtypes.int32) * axes + cast(
+            axes < 0, dtypes.int32) * (
+                axes + rank_a)
+        free, _ = array_ops.setdiff1d(range(rank_a), axes)
+        free_dims = array_ops.gather(shape_a, free)
+        axes_dims = array_ops.gather(shape_a, axes)
+        prod_free_dims = reduce_prod(free_dims)
+        prod_axes_dims = reduce_prod(axes_dims)
+        perm = array_ops.concat([axes_dims, free_dims], 0)
+        if flipped:
+          perm = array_ops.concat([axes, free], 0)
+          new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
+        else:
+          perm = array_ops.concat([free, axes], 0)
+          new_shape = array_ops.stack([prod_free_dims, prod_axes_dims])
       reshaped_a = array_ops.reshape(array_ops.transpose(a, perm), new_shape)
       return reshaped_a, free_dims, free_dims_static
 
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 474e0bb..2526e6f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -510,7 +510,7 @@
 
     # Recover channel information for output shape if channels are not last.
     if self.data_format is not None and self.data_format.startswith("NC"):
-      if not result_converted.shape[1].value:
+      if not result_converted.shape[1].value and filter is not None:
         output_shape = result_converted.shape.as_list()
         output_shape[1] = filter.shape[-1]
         result_converted.set_shape(output_shape)
@@ -2454,7 +2454,7 @@
   returned to the caller.
 
   Args:
-    value: A 3D `Tensor`.  Must be of type `float16` or `float32`.
+    value: A 3D `Tensor`.  Must be of type `float16`, `float32`, or `float64`.
     filters: A 3D `Tensor`.  Must have the same type as `value`.
     stride: An `integer`.  The number of entries by which
       the filter is moved right at each step.
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index c0e66cb..d403b0c 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -1259,7 +1259,7 @@
                                         [3])  # [0, 2, 0]
 
     pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(pfor, feed_dict={num_iters: 3})
 
   def test_sparse_result_none_stacked(self):
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index f9cf16f..628c676 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -356,7 +356,7 @@
     self.run_and_assert_equal(answer, jacobian_while)
 
   def test_jacobian_unknown_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32, shape=[None, None])
       y = math_ops.matmul(x, x, transpose_a=True)
       jacobian_pfor = gradients.jacobian(y, x, use_pfor=True)
@@ -381,7 +381,7 @@
       gradients.batch_jacobian(y, x, use_pfor=True)
 
   def test_batch_jacobian_bad_unknown_shapes(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       y = array_ops.concat([x, x], axis=0)
       jacobian = gradients.batch_jacobian(y, x)
@@ -402,7 +402,7 @@
     self.run_and_assert_equal(answer, batch_jacobian_while)
 
   def test_batch_jacobian_unknown_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       y = x * x
       batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True)
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 3c914f6..f9153b6 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -21,8 +21,6 @@
 
 import collections
 
-from absl import flags
-
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -41,6 +39,7 @@
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import flags
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
@@ -2013,6 +2012,7 @@
 @RegisterPForWithArgs("ReluGrad")
 @RegisterPForWithArgs("TanhGrad")
 @RegisterPForWithArgs("SigmoidGrad")
+@RegisterPForWithArgs("SoftplusGrad")
 def _convert_grads(pfor_input, op_type, *args, **kw_args):
   del args
   del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 8224097..bb8da31 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -1584,7 +1584,8 @@
     record_defaults: A list of `Tensor` objects with specific types.
       Acceptable types are `float32`, `float64`, `int32`, `int64`, `string`.
       One tensor per column of the input record, with either a
-      scalar default value for that column or empty if the column is required.
+      scalar default value for that column or an empty vector if the column is
+      required.
     field_delim: An optional `string`. Defaults to `","`.
       char delimiter to separate fields in a record.
     use_quote_delim: An optional `bool`. Defaults to `True`.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 4800352..55c2eb5 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -750,7 +750,7 @@
 
   def _read_variable_op(self):
     if self.trainable:
-      tape.watch_variable(self)
+      tape.variable_accessed(self)
     result = gen_resource_variable_ops.read_variable_op(self._handle,
                                                         self._dtype)
     if not context.executing_eagerly():
@@ -781,7 +781,7 @@
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
       if self.trainable:
-        tape.watch_variable(self)
+        tape.variable_accessed(self)
       value = gen_resource_variable_ops.resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
     return array_ops.identity(value)
@@ -949,12 +949,12 @@
 
   def _lazy_read(self, op):
     if self.trainable:
-      tape.watch_variable(self)
+      tape.variable_accessed(self)
     return _UnreadVariable(
         handle=self._handle, dtype=self.dtype, shape=self._shape,
         in_graph_mode=self._in_graph_mode,
         deleter=self._handle_deleter if not self._in_graph_mode else None,
-        parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id)
+        parent_op=op, unique_id=self._unique_id)
 
   def assign(self, value, use_locking=None, name=None, read_value=True):
     """Assigns a new value to this variable.
@@ -1293,8 +1293,7 @@
   """
 
   def __init__(self, handle, dtype,  # pylint: disable=super-init-not-called
-               shape, in_graph_mode, deleter, parent_op, parent_name,
-               unique_id):
+               shape, in_graph_mode, deleter, parent_op, unique_id):
     # We do not call super init on purpose.
     self._trainable = False
     self._save_slice_info = None
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 5c00d92..5a3a5cc 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -709,6 +709,10 @@
   Raises:
     ValueError: If the input depth cannot be inferred via shape inference
       from the inputs.
+    ValueError: If time_step is not the same for all the elements in the
+      inputs.
+    ValueError: If batch_size is not the same for all the elements in the
+      inputs.
   """
   state = initial_state
   assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index fa13568..3e19183 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -428,7 +428,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % inputs_shape)
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     self._kernel = self.add_variable(
@@ -525,7 +525,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % inputs_shape)
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     self._gate_kernel = self.add_variable(
@@ -705,7 +705,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % inputs_shape)
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     h_depth = self._num_units
@@ -783,10 +783,10 @@
 
   The default non-peephole implementation is based on:
 
-    http://www.bioinf.jku.at/publications/older/2604.pdf
+    https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
 
-  S. Hochreiter and J. Schmidhuber.
-  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+  Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+  "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
 
   The peephole implementation is based on:
 
@@ -908,7 +908,7 @@
   def build(self, inputs_shape):
     if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
-                       % inputs_shape)
+                       % str(inputs_shape))
 
     input_depth = inputs_shape[-1]
     h_depth = self._num_units if self._num_proj is None else self._num_proj
@@ -954,7 +954,7 @@
     """Run one step of LSTM.
 
     Args:
-      inputs: input Tensor, 2D, `[batch, num_units].
+      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
       state: if `state_is_tuple` is False, this must be a state Tensor,
         `2-D, [batch, state_size]`.  If `state_is_tuple` is True, this must be a
         tuple of state Tensors, both `2-D`, with column sizes `c_state` and
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 8d66de6..2ec4b54 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -287,19 +287,19 @@
 
 # TODO(akshayka): Implement higher-order derivatives.
 @ops.RegisterGradient("EagerPyFunc")
-def _EagerPyFuncGrad(op, dy):
+def _EagerPyFuncGrad(op, *dy):
   """Computes the gradient of an EagerPyFunc."""
 
   token = op.get_attr("token")
 
-  def eagerly_executed_grad(dy):
+  def eagerly_executed_grad(*dy):
     tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
     return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)
 
   with ops.control_dependencies(op.outputs):
     return _internal_py_func(
         func=eagerly_executed_grad,
-        inp=[dy] if isinstance(dy, ops.Tensor) else dy,
+        inp=dy,
         Tout=[tensor.dtype for tensor in op.inputs],
         eager=True,
         is_grad_func=True)
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index c832ba4..b2c6937 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -41,12 +41,41 @@
 from tensorflow.python.util.tf_export import tf_export
 # pylint: enable=wildcard-import
 
+
+# pylint: disable=redefined-builtin
+def regex_full_match(input, pattern, name=None):
+  r"""Match elements of `input` with regex `pattern`.
+
+  Args:
+    input: string `Tensor`, the source strings to process.
+    pattern: string or scalar string `Tensor`, regular expression to use,
+      see more details at https://github.com/google/re2/wiki/Syntax
+    name: Name of the op.
+
+  Returns:
+    bool `Tensor` of the same shape as `input` with match results.
+  """
+  # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+  if not compat.forward_compatible(2018, 11, 10):
+    return gen_string_ops.regex_full_match(
+        input=input, pattern=pattern, name=name)
+  if isinstance(pattern, util_compat.bytes_or_text_types):
+    # When `pattern` is static through the life of the op we can
+    # use a version which performs the expensive regex compilation once at
+    # creation time.
+    return gen_string_ops.static_regex_full_match(
+        input=input, pattern=pattern, name=name)
+  return gen_string_ops.regex_full_match(
+      input=input, pattern=pattern, name=name)
+
+regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
+
 # Expose regex_full_match in strings namespace
 tf_export("strings.regex_full_match")(regex_full_match)
 
 
 def regex_replace(source, pattern, rewrite, replace_global=True):
-  r"""Replace elements of `source` matching regex `pattern with `rewrite`.
+  r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
 
   Args:
     source: string `Tensor`, the source strings to process.
@@ -61,11 +90,6 @@
   Returns:
     string `Tensor` of the same shape as `source` with specified replacements.
   """
-  # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
-  if not compat.forward_compatible(2018, 10, 10):
-    return gen_string_ops.regex_replace(
-        input=source, pattern=pattern,
-        rewrite=rewrite, replace_global=replace_global)
   if (isinstance(pattern, util_compat.bytes_or_text_types) and
       isinstance(rewrite, util_compat.bytes_or_text_types)):
     # When `pattern` and `rewrite` are static through the life of the op we can
@@ -128,6 +152,7 @@
   shape.set_shape([2])
   return sparse_tensor.SparseTensor(indices, values, shape)
 
+
 @tf_export("strings.split")
 def string_split_v2(source, sep=None, maxsplit=-1):
   """Split elements of `source` based on `sep` into a `SparseTensor`.
@@ -170,7 +195,7 @@
     second column corresponds to the index of the split component in this row.
   """
   if sep is None:
-    sep = ''
+    sep = ""
   sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
   source = ops.convert_to_tensor(source, dtype=dtypes.string)
 
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index 45de047..5927bc2 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -33,6 +33,7 @@
 from tensorflow.python.lib.io.file_io import stat as Stat
 from tensorflow.python.lib.io.file_io import walk as Walk
 # pylint: enable=unused-import
+from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -62,6 +63,7 @@
   invocations in network filesystems).
   """
 
+  @deprecated(None, 'Use tf.gfile.GFile.')
   def __init__(self, name, mode='r'):
     super(FastGFile, self).__init__(name=name, mode=mode)
 
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index a31861a..c411a58 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -52,9 +52,10 @@
 %rename("%s") TFE_Py_TapeSetShouldRecord;
 %rename("%s") TFE_Py_TapeSetDeleteTrace;
 %rename("%s") TFE_Py_TapeSetRecordOperation;
-%rename("%s") TFE_Py_TapeSetWatchVariable;
 %rename("%s") TFE_Py_TapeGradient;
+%rename("%s") TFE_Py_TapeVariableAccessed;
 %rename("%s") TFE_Py_TapeWatch;
+%rename("%s") TFE_Py_TapeWatchVariable;
 %rename("%s") TFE_Py_TapeWatchedVariables;
 %rename("%s") TFE_NewContextOptions;
 %rename("%s") TFE_ContextOptionsSetConfig;
@@ -65,6 +66,7 @@
 %rename("%s") TFE_Py_TensorShapeOnDevice;
 %rename("%s") TFE_ContextStartStep;
 %rename("%s") TFE_ContextEndStep;
+%rename("%s") TFE_Py_RegisterVSpace;
 
 %{
 #include "tensorflow/python/eager/pywrap_tfe.h"
@@ -186,7 +188,10 @@
                         "outputs of the operation)");
   }
   $1 = &temp;
-  $1->resize(PyInt_AsLong($input), nullptr);
+  long sz = PyInt_AsLong($input);
+  if (sz > 0) {
+    $1->resize(PyInt_AsLong($input), nullptr);
+  }
 }
 
 // Create new Status object.
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 7a37eda..c9bc33e 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -225,6 +225,7 @@
         ":signature_constants",
         ":utils",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework_ops",
         "//tensorflow/python:util",
     ],
 )
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index 5eeaf73..fe69f3b 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -91,10 +91,17 @@
 
 #### Tags
 Each meta graph added to the SavedModel must be annotated with user specified
-tags. The tags provide a means to identify the specific meta graph to load and
-restore, along with the shared set of variables and assets. These tags
-typically annotate a MetaGraph with its functionality (e.g. serving or
-training), and possibly hardware specific aspects such as GPU.
+tags, which reflect the meta graph capabilities or use-cases.
+More specifically, these tags typically annotate a meta graph with its
+functionality (e.g. serving or training), and possibly hardware specific aspects
+such as GPU.
+In the SavedModel, the meta graph def whose tag-set exactly matches those
+specified in the loader API, will be the one loaded by the loader.
+If no meta graph def is found matching the specified tags, an error is returned.
+For example, a loader with a requirement to serve on GPU hardware would be able
+to load only the meta graph annotated with tags='serve,gpu' by specifying this set
+of tags in tensorflow::LoadSavedModel(...).
+
 
 #### Usage
 The typical usage of `builder` is as follows:
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index f8ad788..37f927f 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -21,9 +21,7 @@
 
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import utils
 from tensorflow.python.util.tf_export import tf_export
@@ -316,80 +314,3 @@
 
   return True
 
-
-def _get_shapes_from_tensor_info_dict(tensor_info_dict):
-  """Returns a map of keys to TensorShape objects.
-
-  Args:
-    tensor_info_dict: map with TensorInfo proto as values.
-
-  Returns:
-    Map with corresponding TensorShape objects as values.
-  """
-  return {
-      key: tensor_shape.TensorShape(tensor_info.tensor_shape)
-      for key, tensor_info in tensor_info_dict.items()
-  }
-
-
-def _get_types_from_tensor_info_dict(tensor_info_dict):
-  """Returns a map of keys to DType objects.
-
-  Args:
-    tensor_info_dict: map with TensorInfo proto as values.
-
-  Returns:
-    Map with corresponding DType objects as values.
-  """
-  return {
-      key: dtypes.DType(tensor_info.dtype)
-      for key, tensor_info in tensor_info_dict.items()
-  }
-
-
-def get_signature_def_input_shapes(signature):
-  """Returns map of parameter names to their shapes.
-
-  Args:
-    signature: SignatureDef proto.
-
-  Returns:
-    Map from string to TensorShape objects.
-  """
-  return _get_shapes_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_input_types(signature):
-  """Returns map of output names to their types.
-
-  Args:
-    signature: SignatureDef proto.
-
-  Returns:
-    Map from string to DType objects.
-  """
-  return _get_types_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_output_shapes(signature):
-  """Returns map of output names to their shapes.
-
-  Args:
-    signature: SignatureDef proto.
-
-  Returns:
-    Map from string to TensorShape objects.
-  """
-  return _get_shapes_from_tensor_info_dict(signature.outputs)
-
-
-def get_signature_def_output_types(signature):
-  """Returns map of output names to their types.
-
-  Args:
-    signature: SignatureDef proto.
-
-  Returns:
-    Map from string to DType objects.
-  """
-  return _get_types_from_tensor_info_dict(signature.outputs)
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index ebc5450..18c55d8 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -275,44 +275,6 @@
     self.assertEqual(method_name, signature_def.method_name)
     self.assertEqual(3, len(signature_def.outputs))
 
-  def testGetShapeAndTypes(self):
-    inputs = {
-        "input-1": constant_op.constant(["a", "b"]),
-        "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
-    }
-    outputs = {
-        "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
-        "output-2": constant_op.constant([["b"]]),
-    }
-    signature_def = _make_signature(inputs, outputs)
-    self.assertEqual(
-        signature_def_utils_impl.get_signature_def_input_shapes(signature_def),
-        {"input-1": [2], "input-2": [10, 11]})
-    self.assertEqual(
-        signature_def_utils_impl.get_signature_def_output_shapes(signature_def),
-        {"output-1": [10, 32], "output-2": [1, 1]})
-    self.assertEqual(
-        signature_def_utils_impl.get_signature_def_input_types(signature_def),
-        {"input-1": dtypes.string, "input-2": dtypes.float32})
-    self.assertEqual(
-        signature_def_utils_impl.get_signature_def_output_types(signature_def),
-        {"output-1": dtypes.float32, "output-2": dtypes.string})
-
-  def testGetNonFullySpecifiedShapes(self):
-    outputs = {
-        "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
-        "output-2": array_ops.sparse_placeholder(dtypes.float32),
-    }
-    signature_def = _make_signature({}, outputs)
-    shapes = signature_def_utils_impl.get_signature_def_output_shapes(
-        signature_def)
-    self.assertEqual(len(shapes), 2)
-    # Must compare shapes with as_list() since 2 equivalent non-fully defined
-    # shapes are not equal to each other.
-    self.assertEqual(shapes["output-1"].as_list(), [None, 10, None])
-    # Must compare `dims` since its an unknown shape.
-    self.assertEqual(shapes["output-2"].dims, None)
-
   def _assertValidSignature(self, inputs, outputs, method_name):
     signature_def = signature_def_utils_impl.build_signature_def(
         inputs, outputs, method_name)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 01d43e0..1c1a1a5 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -137,6 +137,7 @@
     size = "small",
     srcs = ["strip_unused_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["notap"],
     deps = [
         ":strip_unused_lib",
         "//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index 2810d83..271cf2a 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -12,10 +12,15 @@
     # END GENERATED ESTIMATOR FILES
 ]
 
+def get_compat_files(
+        file_paths,
+        compat_api_version):
+    """Prepends compat/v<compat_api_version> to file_paths."""
+    return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]
+
 def gen_api_init_files(
         name,
         output_files = TENSORFLOW_API_INIT_FILES,
-        compat_output_files = {},
         root_init_template = None,
         srcs = [],
         api_name = "tensorflow",
@@ -23,7 +28,8 @@
         compat_api_versions = [],
         package = "tensorflow.python",
         package_dep = "//tensorflow/python:no_contrib",
-        output_package = "tensorflow"):
+        output_package = "tensorflow",
+        output_dir = ""):
     """Creates API directory structure and __init__.py files.
 
     Creates a genrule that generates a directory structure with __init__.py
@@ -37,8 +43,6 @@
         tf_export. For e.g. if an op is decorated with
         @tf_export('module1.module2', 'module3'). Then, output_files should
         include module1/module2/__init__.py and module3/__init__.py.
-      compat_output_files: Dictionary mapping each compat_api_version to the
-        set of __init__.py file paths that should be generated for that version.
       root_init_template: Python init file that should be used as template for
         root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
         template will be replaced with root imports collected by this genrule.
@@ -53,14 +57,16 @@
         process
       package_dep: Python library target containing your package.
       output_package: Package where generated API will be added to.
+      output_dir: Subdirectory to output API to.
+        If non-empty, must end with '/'.
     """
     root_init_template_flag = ""
     if root_init_template:
         root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
 
-    api_gen_binary_target = "create_" + package + "_api"
+    api_gen_binary_target = ("create_" + package + "_api_%d") % api_version
     native.py_binary(
-        name = "create_" + package + "_api",
+        name = api_gen_binary_target,
         srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
         main = "//tensorflow/python/tools/api/generator:create_python_api.py",
         srcs_version = "PY2AND3",
@@ -72,14 +78,9 @@
         ],
     )
 
-    all_output_files = list(output_files)
+    all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
     compat_api_version_flags = ""
     for compat_api_version in compat_api_versions:
-        compat_files = compat_output_files.get(compat_api_version, [])
-        all_output_files.extend([
-            "compat/v%d/%s" % (compat_api_version, f)
-            for f in compat_files
-        ])
         compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
 
     native.genrule(
@@ -87,12 +88,15 @@
         outs = all_output_files,
         cmd = (
             "$(location :" + api_gen_binary_target + ") " +
-            root_init_template_flag + " --apidir=$(@D) --apiname=" +
-            api_name + " --apiversion=" + str(api_version) +
+            root_init_template_flag + " --apidir=$(@D)" + output_dir +
+            " --apiname=" + api_name + " --apiversion=" + str(api_version) +
             compat_api_version_flags + " --package=" + package +
             " --output_package=" + output_package + " $(OUTS)"
         ),
         srcs = srcs,
         tools = [":" + api_gen_binary_target],
-        visibility = ["//tensorflow:__pkg__"],
+        visibility = [
+            "//tensorflow:__pkg__",
+            "//tensorflow/tools/api/tests:__pkg__",
+        ],
     )
diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py
index 988ecc6..97f4671 100644
--- a/tensorflow/python/tools/component_api_helper.py
+++ b/tensorflow/python/tools/component_api_helper.py
@@ -65,9 +65,10 @@
     Will allow the following import statement to work.
     >>> import parent.child
     """
-    child_pkg_path = [os.path.join(os.path.dirname(child_pkg.__file__), "..")]
+    child_pkg_path = [os.path.abspath(
+        os.path.join(os.path.dirname(child_pkg.__file__), ".."))]
     try:
-      parent_pkg.__path__ += child_pkg_path
+      parent_pkg.__path__ = child_pkg_path + parent_pkg.__path__
     except AttributeError:
       parent_pkg.__path__ = child_pkg_path
 
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
index 4b3d982..cce8060 100644
--- a/tensorflow/python/tools/print_selective_registration_header_test.py
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -59,6 +59,9 @@
   }
 """
 
+# AccumulateNV2 is included because it should be included in the header despite
+# lacking a kernel (it's rewritten by AccumulateNV2RemovePass; see
+# core/common_runtime/accumulate_n_optimizer.cc).
 GRAPH_DEF_TXT_2 = """
   node: {
     name: "node_4"
@@ -67,6 +70,12 @@
     device: "/cpu:0"
     attr: { key: "T" value: { type: DT_FLOAT } }
   }
+  node: {
+    name: "node_5"
+    op: "AccumulateNV2"
+    attr: { key: "T" value: { type: DT_INT32 } }
+    attr: { key  : "N" value: { i: 3 } }
+  }
 
 """
 
@@ -100,6 +109,7 @@
 
     self.assertListEqual(
         [
+            ('AccumulateNV2', None),  #
             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
             ('MatMul',
              matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
@@ -117,6 +127,7 @@
         'rawproto', self.WriteGraphFiles(graphs), default_ops)
     self.assertListEqual(
         [
+            ('AccumulateNV2', None),  #
             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
             ('MatMul',
              matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
@@ -196,6 +207,7 @@
 
 constexpr inline bool ShouldRegisterOp(const char op[]) {
   return false
+     || isequal(op, "AccumulateNV2")
      || isequal(op, "BiasAdd")
   ;
 }
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 6716c79..d8ba13d8 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -33,7 +33,6 @@
 
 from six import integer_types
 from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
 from tensorflow.core.example import example_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.client import session
@@ -97,8 +96,7 @@
   Returns:
     A dictionary that maps input tensor keys to TensorInfos.
   """
-  return signature_def_utils.get_signature_def_by_key(meta_graph_def,
-                                                      signature_def_key).inputs
+  return meta_graph_def.signature_def[signature_def_key].inputs
 
 
 def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
@@ -116,8 +114,7 @@
   Returns:
     A dictionary that maps output tensor keys to TensorInfos.
   """
-  return signature_def_utils.get_signature_def_by_key(meta_graph_def,
-                                                      signature_def_key).outputs
+  return meta_graph_def.signature_def[signature_def_key].outputs
 
 
 def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
@@ -546,7 +543,7 @@
   input_examples = preprocess_input_examples_arg_string(input_examples_str)
 
   for input_tensor_key, (filename, variable_name) in inputs.items():
-    data = np.load(file_io.FileIO(filename, mode='r'))
+    data = np.load(file_io.FileIO(filename, mode='rb'))
 
     # When a variable_name key is specified for the input file
     if variable_name:
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index dc0612b..b99c632 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -32,6 +32,16 @@
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging
 
+# Usually, we use each graph node to induce registration of an op and
+# corresponding kernel; nodes without a corresponding kernel (perhaps due to
+# attr types) generate a warning but are otherwise ignored. Ops in this set are
+# registered even if there's no corresponding kernel.
+OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
+    # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
+    # core/common_runtime/accumulate_n_optimizer.cc.
+    'AccumulateNV2'
+])
+
 
 def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
   """Gets the ops and kernels needed from the model files."""
@@ -53,8 +63,10 @@
         node_def.device = '/cpu:0'
       kernel_class = pywrap_tensorflow.TryFindKernelClass(
           node_def.SerializeToString())
-      if kernel_class:
-        op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8')))
+      op = str(node_def.op)
+      if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
+        op_and_kernel = (op, str(kernel_class.decode('utf-8'))
+                         if kernel_class else None)
         if op_and_kernel not in ops:
           ops.add(op_and_kernel)
       else:
@@ -129,6 +141,7 @@
     '''
     line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
     for _, kernel_class in ops_and_kernels:
+      if kernel_class is None: continue
       line += '"%s",\n' % kernel_class
     line += '};'
     append(line)
diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py
index 2678016..a14ac89 100644
--- a/tensorflow/python/training/adadelta_test.py
+++ b/tensorflow/python/training/adadelta_test.py
@@ -155,7 +155,7 @@
                   rtol=1e-5)
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       self.doTestBasic(use_resource=False)
 
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -168,7 +168,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py
index c3a242a..00801be 100644
--- a/tensorflow/python/training/adagrad_da_test.py
+++ b/tensorflow/python/training/adagrad_da_test.py
@@ -34,7 +34,7 @@
 
   def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
     for dtype in [dtypes.float64, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         global_step = variables.Variable(0, dtype=dtypes.int64)
         if use_resource:
           var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -81,7 +81,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         global_step = resource_variable_ops.ResourceVariable(
             0, dtype=dtypes.int64)
@@ -101,7 +101,7 @@
 
   def testAdagradDAwithoutRegularizationBasic2(self):
     for dtype in [dtypes.float64, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         global_step = variables.Variable(0, dtype=dtypes.int64)
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -133,7 +133,7 @@
 
   def testAdagradDAWithL1(self):
     for dtype in [dtypes.float64, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         global_step = variables.Variable(0, dtype=dtypes.int64)
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
@@ -165,7 +165,7 @@
 
   def testAdagradDAWithL1_L2(self):
     for dtype in [dtypes.float64, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         global_step = variables.Variable(0, dtype=dtypes.int64)
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index 4e634ff..7caf01f 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -98,7 +98,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable(
             [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -117,7 +117,7 @@
 
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -141,7 +141,7 @@
 
   def testSparseBasic(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
         grads0 = ops.IndexedSlices(
@@ -172,7 +172,7 @@
 
   def testSparseRepeatedIndices(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         repeated_index_update_var = variables.Variable(
             [[1.0], [2.0]], dtype=dtype)
         aggregated_update_var = variables.Variable(
@@ -202,7 +202,7 @@
 
   def testSparseRepeatedIndicesResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var_repeated = resource_variable_ops.ResourceVariable(
             [1.0, 2.0], dtype=dtype)
         loss_repeated = math_ops.reduce_sum(
@@ -226,7 +226,7 @@
 
   def testSparseStability(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         shape = [1, 6]
         var0 = variables.Variable(
             [[
@@ -262,7 +262,7 @@
 
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -295,7 +295,7 @@
             np.array([2.715679168701172, 3.715679168701172]), var1.eval())
 
   def testDynamicShapeVariable_Ok(self):
-    with self.test_session():
+    with self.cached_session():
       v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
                                       validate_shape=False)
       self.assertFalse(v.shape.is_fully_defined())
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 778c672..48db6e3 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -56,7 +56,7 @@
 
   def doTestSparse(self, use_resource=False):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -122,7 +122,7 @@
 
   def testSparseRepeatedIndices(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         repeated_index_update_var = variables.Variable(
             [[1.0], [2.0]], dtype=dtype)
         aggregated_update_var = variables.Variable(
@@ -224,7 +224,7 @@
                              opt.get_slot(var=var0, name="m").name)
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       self.doTestBasic(use_resource=False)
 
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -237,7 +237,7 @@
 
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -274,7 +274,7 @@
 
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 7662562..3bd4bd7 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -1025,7 +1025,7 @@
 
   def before_run(self, run_context):
     self._request_summary = (
-        self._next_step is None or
+        self._next_step is not None and
         self._timer.should_trigger_for_step(self._next_step))
     requests = {"global_step": self._global_step_tensor}
     opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1035,6 +1035,10 @@
 
   def after_run(self, run_context, run_values):
     stale_global_step = run_values.results["global_step"]
+    if self._next_step is None:
+      # Update the timer so that it does not activate until N steps or seconds
+      # have passed.
+      self._timer.update_last_triggered_step(stale_global_step)
     global_step = stale_global_step + 1
     if self._request_summary:
       global_step = run_context.session.run(self._global_step_tensor)
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871..2d46963 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1145,7 +1145,7 @@
         summary_writer=self.summary_writer,
         summary_op=self.summary_op)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       hook.begin()
       sess.run(variables_lib.global_variables_initializer())
       mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1177,7 +1177,7 @@
         summary_writer=self.summary_writer,
         summary_op=[self.summary_op, self.summary_op2])
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       hook.begin()
       sess.run(variables_lib.global_variables_initializer())
       mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1205,7 +1205,7 @@
         summary_writer=self.summary_writer,
         summary_op=self.summary_op)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       hook.begin()
       sess.run(variables_lib.global_variables_initializer())
       mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1240,7 +1240,7 @@
         summary_writer=self.summary_writer,
         summary_op=self.summary_op)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       hook.begin()
       sess.run(variables_lib.global_variables_initializer())
       mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1388,7 +1388,7 @@
         summary_writer=self.summary_writer,
         summary_op=self.summary_op)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       hook.begin()
       sess.run(variables_lib.global_variables_initializer())
       mon_sess = monitored_session._HookedSession(sess, [hook])
@@ -1454,52 +1454,50 @@
     with self.assertRaises(ValueError):
       basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
 
-  def test_save_secs_saves_in_first_step(self):
+  def test_save_secs_does_not_save_in_first_step(self):
     with self.graph.as_default():
       hook = basic_session_run_hooks.ProfilerHook(
           save_secs=2, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
         sess.run(self.train_op)
-        self.assertEqual(1, self._count_timeline_files())
+        self.assertEqual(0, self._count_timeline_files())
 
   @test.mock.patch.object(time, 'time')
   def test_save_secs_saves_periodically(self, mock_time):
     # Pick a fixed start time.
-    current_time = 1484863632.320497
+    current_time = 1484863632.
 
     with self.graph.as_default():
       mock_time.return_value = current_time
       hook = basic_session_run_hooks.ProfilerHook(
           save_secs=2, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
-        sess.run(self.train_op)  # Saved.
-        self.assertEqual(1, self._count_timeline_files())
         sess.run(self.train_op)  # Not saved.
-        self.assertEqual(1, self._count_timeline_files())
+        self.assertEqual(0, self._count_timeline_files())
         # Simulate 2.5 seconds of sleep.
         mock_time.return_value = current_time + 2.5
         sess.run(self.train_op)  # Saved.
+        self.assertEqual(1, self._count_timeline_files())
 
         # Pretend some small amount of time has passed.
-        mock_time.return_value = current_time + 0.1
+        mock_time.return_value = current_time + 2.6
         sess.run(self.train_op)  # Not saved.
         # Edge test just before we should save the timeline.
-        mock_time.return_value = current_time + 1.9
+        mock_time.return_value = current_time + 4.4
         sess.run(self.train_op)  # Not saved.
-        self.assertEqual(2, self._count_timeline_files())
+        self.assertEqual(1, self._count_timeline_files())
 
         mock_time.return_value = current_time + 4.5
         sess.run(self.train_op)  # Saved.
-        self.assertEqual(3, self._count_timeline_files())
+        self.assertEqual(2, self._count_timeline_files())
 
-  def test_save_steps_saves_in_first_step(self):
+  def test_save_steps_does_not_save_in_first_step(self):
     with self.graph.as_default():
       hook = basic_session_run_hooks.ProfilerHook(
-          save_secs=2, output_dir=self.output_dir)
+          save_steps=1, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
-        sess.run(self.train_op)  # Saved.
         sess.run(self.train_op)  # Not saved.
-        self.assertEqual(1, self._count_timeline_files())
+        self.assertEqual(0, self._count_timeline_files())
 
   def test_save_steps_saves_periodically(self):
     with self.graph.as_default():
@@ -1507,6 +1505,8 @@
           save_steps=2, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
         self.assertEqual(0, self._count_timeline_files())
+        sess.run(self.train_op)  # Not saved.
+        self.assertEqual(0, self._count_timeline_files())
         sess.run(self.train_op)  # Saved.
         self.assertEqual(1, self._count_timeline_files())
         sess.run(self.train_op)  # Not saved.
@@ -1515,20 +1515,19 @@
         self.assertEqual(2, self._count_timeline_files())
         sess.run(self.train_op)  # Not saved.
         self.assertEqual(2, self._count_timeline_files())
-        sess.run(self.train_op)  # Saved.
-        self.assertEqual(3, self._count_timeline_files())
 
-  def test_run_metadata_saves_in_first_step(self):
+  def test_run_metadata_saves(self):
     writer_cache.FileWriterCache.clear()
     fake_summary_writer.FakeSummaryWriter.install()
     fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
     with self.graph.as_default():
       hook = basic_session_run_hooks.ProfilerHook(
-          save_secs=2, output_dir=self.output_dir)
+          save_steps=1, output_dir=self.output_dir)
       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+        sess.run(self.train_op)  # Not saved.
         sess.run(self.train_op)  # Saved.
         self.assertEqual(
-            list(fake_writer._added_run_metadata.keys()), ['step_1'])
+            list(fake_writer._added_run_metadata.keys()), ['step_2'])
     fake_summary_writer.FakeSummaryWriter.uninstall()
 
 
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 8ef5048..3a061bc 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -73,7 +73,7 @@
         # Collides with the default name of the checkpoint state file.
         filepath = os.path.join(traindir, "checkpoint")
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           unused_a = variables.Variable(0.0)  # So that Saver saves something.
           variables.global_variables_initializer().run()
 
@@ -113,7 +113,7 @@
         filename = "snapshot"
         filepath = os.path.join(traindir, filename)
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           # Build a simple graph.
           v0 = variables.Variable(0.0)
           inc = v0.assign_add(1.0)
@@ -128,7 +128,7 @@
           inc.eval()
           save.save(sess, filepath, global_step=2)
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           # Build a new graph with different initialization.
           v0 = variables.Variable(-1.0)
 
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
index a6e9662..cfd9b39 100644
--- a/tensorflow/python/training/checkpoint_ops.py
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -268,7 +268,8 @@
   vocab files are the same, and no column remapping is done.
 
   The returned initializer only supports div-partitioning along the row axis. It
-  does not support partitioning along the column axis or mod-partitioning.
+  does not support partitioning along the column axis (as this is not common in
+  practice) or mod-partitioning.
 
   NOTE: When this is used to warm-start variables, client code should use
   `tf.lookup.index_table_from_tensor()` like
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
index 00611de..dde8431 100644
--- a/tensorflow/python/training/checkpoint_ops_test.py
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -43,7 +43,7 @@
     # 0., 1., ..., 79. reshaped into [5, 16].
     initializer = init_ops.constant_initializer(
         np.reshape(np.linspace(0.0, 79, 5 * 16), (5, 16)))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with variable_scope.variable_scope('some_scope'):
         variable_scope.get_variable(name='embeddings', shape=[5, 16],
                                     initializer=initializer)
@@ -114,7 +114,7 @@
         ],
         axis=1)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
 
   def test_load_and_remap_output_layer_weight_initializer_linear(self):
@@ -150,7 +150,7 @@
         initializer=loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_matrix,
                           remapped_matrix.as_tensor().eval())
@@ -184,7 +184,7 @@
         initializer=loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_matrix,
                           remapped_matrix.as_tensor().eval())
@@ -222,7 +222,7 @@
         initializer=loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_matrix,
                           remapped_matrix.as_tensor().eval())
@@ -258,7 +258,7 @@
         initializer=loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_matrix,
                           remapped_matrix.as_tensor().eval())
@@ -292,7 +292,7 @@
         initializer=embedding_loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_embeddings,
                           remapped_embeddings.as_tensor().eval())
@@ -338,7 +338,7 @@
         initializer=embedding_loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_embeddings,
                           remapped_embeddings.as_tensor().eval())
@@ -376,7 +376,7 @@
         initializer=embedding_loading_initializer,
         partitioner=partitioned_variables.fixed_size_partitioner(2))
 
-    with self.test_session():
+    with self.cached_session():
       variables.global_variables_initializer().run()
       self.assertAllClose(expected_remapped_embeddings,
                           remapped_embeddings.as_tensor().eval())
diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py
index 1aab163..61dcbdb 100644
--- a/tensorflow/python/training/checkpoint_utils_test.py
+++ b/tensorflow/python/training/checkpoint_utils_test.py
@@ -84,7 +84,7 @@
 
   def testNoTensor(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
     with self.assertRaises(errors_impl.OpError):
       self.assertAllEqual(
@@ -92,7 +92,7 @@
 
   def testGetTensor(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -105,7 +105,7 @@
 
   def testGetAllVariables(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _create_checkpoints(session, checkpoint_dir)
     self.assertEqual(
         checkpoint_utils.list_variables(checkpoint_dir),
@@ -114,7 +114,7 @@
 
   def testInitFromCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -148,7 +148,7 @@
 
   def testInitialValueComesFromCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -178,7 +178,7 @@
 
   def testInitWithScopeDoesNotCaptureSuffixes(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
 
     with ops.Graph().as_default() as g:
@@ -197,7 +197,7 @@
 
   def testRestoreRunsOnSameDevice(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _create_checkpoints(session, checkpoint_dir)
 
     with ops.Graph().as_default():
@@ -213,7 +213,7 @@
 
   def testInitFromRootCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -237,7 +237,7 @@
 
   def testInitToRootCheckpoint(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -260,7 +260,7 @@
 
   def testInitFromPartitionVar(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1 = _create_partition_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -322,7 +322,7 @@
 
   def testInitFromCheckpointMissing(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
@@ -367,7 +367,7 @@
 
   def testNoAdditionalReadOpsForResourceVariables(self):
     checkpoint_dir = self.get_temp_dir()
-    with self.test_session() as session:
+    with self.cached_session() as session:
       v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)
 
     # New graph and session.
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f..095a90d 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import functools
 import json
 import weakref
 
+import six
+
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@
     return self._checkpoint_position
 
 
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+  """Embeds a tensor in a checkpoint with no restore ops."""
+
+  def __init__(self, tensor, name, dtype=None):
+    spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+  """An interface for saving/restoring volatile Python state."""
+
+  @abc.abstractmethod
+  def feed_dict_additions(self):
+    """When running a graph, indicates fresh state to feed.
+
+    Returns:
+      A dictionary mapping `Tensor`s to current Python state.
+    """
+    pass
+
+  @abc.abstractmethod
+  def freeze(self):
+    """Create a new `SaveableObject` which freezes current state as a constant.
+
+    Used when executing eagerly to embed the current state as a constant, or
+    when creating a static tf.train.Saver with the frozen current Python state.
+
+    Returns:
+      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+      no Python state associated with it).
+    """
+    pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
   """Saves Python state in a checkpoint."""
 
   def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@
       restore_callback: A function taking a Python string, used to restore
         state. Optional; defaults to doing nothing.
     """
+    self._state_callback = state_callback
     self._restore_callback = restore_callback
-    if context.executing_eagerly():
-      self._save_string = (
-          lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
-    else:
+    with ops.device("/cpu:0"):
       self._save_string = constant_op.constant("", dtype=dtypes.string)
-      self.feed_dict_additions = (
-          lambda: {self._save_string: state_callback()})
     spec = saveable_object.SaveSpec(
         self._save_string, "", name, dtype=dtypes.string)
     super(PythonStringStateSaveable, self).__init__(
         self._save_string, [spec], name)
 
+  def feed_dict_additions(self):
+    """When running a graph, indicates fresh state to feed."""
+    return {self._save_string: self._state_callback()}
+
+  def freeze(self):
+    """Create a frozen `SaveableObject` which saves the current state."""
+    return NoRestoreSaveable(
+        tensor=self._state_callback,
+        dtype=dtypes.string,
+        name=self.name)
+
   def python_restore(self, restored_strings):
     """Called to restore Python state."""
     if self._restore_callback:
@@ -309,7 +357,7 @@
         if self._checkpoint.saveable_object_cache is not None:
           self._checkpoint.saveable_object_cache.setdefault(
               self.checkpointable, {})[serialized_tensor.name] = [saveable]
-      if isinstance(saveable, PythonStringStateSaveable):
+      if isinstance(saveable, PythonStateSaveable):
         python_saveables.append(saveable)
       else:
         named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@
     def _state_callback():
       dereferenced_self = weak_self()
       if dereferenced_self:
-        return json.dumps(self,
+        return json.dumps(dereferenced_self,
                           default=serialization.get_json_type,
                           sort_keys=True).encode("utf8")
       else:
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index f06cbbf..c29e5db 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 import collections
+import copy
 
 import six
 
@@ -251,6 +252,12 @@
       self._storage[index] = self._track_value(
           element, name=self._name_element(index))
 
+  def __copy__(self):
+    return type(self)(copy.copy(self._storage))
+
+  def __deepcopy__(self, memo):
+    return type(self)(copy.deepcopy(self._storage, memo))
+
   def _make_storage(self, *args, **kwargs):
     """Determines the backing storage (overridden in subclasses)."""
     return list(*args, **kwargs)
@@ -325,6 +332,20 @@
     super(_ListWrapper, self).__init__(wrapped_list)
     self._last_wrapped_list_snapshot = list(self._storage)
 
+  # pylint: disable=protected-access
+  def __copy__(self):
+    copied = super(_ListWrapper, self).__copy__()
+    copied._non_append_mutation = self._non_append_mutation
+    copied._external_modification = self._external_modification
+    return copied
+
+  def __deepcopy__(self, memo):
+    copied = super(_ListWrapper, self).__deepcopy__(memo)
+    copied._non_append_mutation = self._non_append_mutation
+    copied._external_modification = self._external_modification
+    return copied
+  # pylint: enable=protected-access
+
   def _make_storage(self, wrapped_list):
     """Use the user's original list for storage."""
     return wrapped_list
@@ -449,6 +470,12 @@
             value, name=self._name_element(key))
          for key, value in self._storage.items()})
 
+  def __copy__(self):
+    return type(self)(copy.copy(self._storage))
+
+  def __deepcopy__(self, memo):
+    return type(self)(copy.deepcopy(self._storage, memo))
+
   def _make_storage(self, *args, **kwargs):
     return dict(*args, **kwargs)
 
@@ -525,6 +552,22 @@
     super(_DictWrapper, self).__init__(wrapped_dict)
     self._update_snapshot()
 
+  # pylint: disable=protected-access
+  def __copy__(self):
+    copied = super(_DictWrapper, self).__copy__()
+    copied._non_append_mutation = self._non_append_mutation
+    copied._external_modification = self._external_modification
+    copied._non_string_key = self._non_string_key
+    return copied
+
+  def __deepcopy__(self, memo):
+    copied = super(_DictWrapper, self).__deepcopy__(memo)
+    copied._non_append_mutation = self._non_append_mutation
+    copied._external_modification = self._external_modification
+    copied._non_string_key = self._non_string_key
+    return copied
+  # pylint: enable=protected-access
+
   def _make_storage(self, wrapped_dict):
     """Re-use the wrapped dict for storage (to force them to be in sync)."""
     return wrapped_dict
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 4638917..5597c7c 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -16,6 +16,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import copy
 import os
 
 import numpy
@@ -424,6 +425,104 @@
     new_dict.update(model.d)
     self.assertEqual({1: 3}, new_dict)
 
+  def testListShallowCopy(self):
+    root = tracking.Checkpointable()
+    orig_list = [[1.]]
+    root.a = orig_list
+    copied = copy.copy(root.a)
+    self.assertAllEqual([[1.]], copied)
+    self.assertIsNot(root.a, copied)
+    self.assertIs(root.a[0], copied[0])
+
+    # Dirtiness should be inherited
+    util.list_objects(root.a)
+    orig_list.append(1.)
+    with self.assertRaises(ValueError):
+      util.list_objects(root.a)
+    with self.assertRaises(ValueError):
+      util.list_objects(copy.copy(root.a))
+
+  def testListDeepCopy(self):
+    root = tracking.Checkpointable()
+    orig_list = [[1.]]
+    root.a = orig_list
+    copied = copy.deepcopy(root.a)
+    self.assertAllEqual([[1.]], copied)
+    self.assertIsNot(root.a, copied)
+    self.assertIsNot(root.a[0], copied[0])
+
+    # Dirtiness should be inherited
+    util.list_objects(root.a)
+    orig_list.append(1.)
+    with self.assertRaises(ValueError):
+      util.list_objects(root.a)
+    with self.assertRaises(ValueError):
+      util.list_objects(copy.deepcopy(root.a))
+
+  def testDictShallowCopy(self):
+    root = tracking.Checkpointable()
+    orig_dict = {"a": [1.]}
+    root.a = orig_dict
+    copied = copy.copy(root.a)
+    self.assertAllEqual([1.], copied["a"])
+    self.assertIsNot(root.a, copied)
+    self.assertIs(root.a["a"], copied["a"])
+
+    # Dirtiness should be inherited
+    util.list_objects(root.a)
+    orig_dict["b"] = []
+    with self.assertRaises(ValueError):
+      util.list_objects(root.a)
+    with self.assertRaises(ValueError):
+      util.list_objects(copy.copy(root.a))
+
+  def testDictDeepCopy(self):
+    root = tracking.Checkpointable()
+    orig_dict = {"a": [1.]}
+    root.a = orig_dict
+    copied = copy.deepcopy(root.a)
+    self.assertAllEqual([1.], copied["a"])
+    self.assertIsNot(root.a, copied)
+    self.assertIsNot(root.a["a"], copied["a"])
+
+    # Dirtiness should be inherited
+    util.list_objects(root.a)
+    orig_dict["b"] = []
+    with self.assertRaises(ValueError):
+      util.list_objects(root.a)
+    with self.assertRaises(ValueError):
+      util.list_objects(copy.deepcopy(root.a))
+
+  def testShallowCopyCheckpointable(self):
+    original = tracking.Checkpointable()
+    original_sub = tracking.Checkpointable()
+    original.a = [[1.]]
+    original.b = {"a": original_sub}
+    shallow_copied = copy.copy(original)
+    self.assertIs(original_sub, shallow_copied.b["a"])
+    self.assertIsNot(original, shallow_copied)
+    self.assertEqual([[1.]], shallow_copied.a)
+    shallow_deps = util.list_objects(shallow_copied)
+    self.assertIn(shallow_copied.a, shallow_deps)
+    self.assertIn(shallow_copied.b, shallow_deps)
+    self.assertIn(shallow_copied.b["a"], shallow_deps)
+
+  def testDeepCopyCheckpointable(self):
+    original = tracking.Checkpointable()
+    original_sub = tracking.Checkpointable()
+    original.a = [[1.]]
+    original.b = {"a": original_sub}
+    deep_copied = copy.deepcopy(original)
+    self.assertIsNot(original, deep_copied)
+    self.assertIsNot(original_sub, deep_copied.b["a"])
+    self.assertEqual([[1.]], deep_copied.a)
+    self.assertIsInstance(deep_copied.b["a"], tracking.Checkpointable)
+    deps = util.list_objects(deep_copied)
+    self.assertIn(deep_copied.a, deps)
+    self.assertIn(deep_copied.b, deps)
+    self.assertIn(deep_copied.b["a"], deps)
+    self.assertNotIn(original_sub, deps)
+
   def testConstructableFromSequence(self):
     result = data_structures._DictWrapper([(1, 2), (3, 4)])
     self.assertIsInstance(result, dict)
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index e85f812..a44c570 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -165,7 +165,7 @@
     self.assertEqual([c], a.attribute["c"].layers)
     checkpoint = util.Checkpoint(a=a)
     save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
-    with self.test_session():
+    with self.cached_session():
       checkpoint.restore(save_path).assert_consumed().initialize_or_restore()
 
   @test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 13dddd3..56c4043 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_io_ops as io_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@
   object_graph_proto = (
       checkpointable_object_graph_pb2.CheckpointableObjectGraph())
   named_saveables = []
-  feed_additions = {}
+  if saveables_cache is None:
+    # No SaveableObject caching. Either we're executing eagerly, or building a
+    # static save which is specialized to the current Python state.
+    feed_additions = None
+  else:
+    # If we are caching SaveableObjects, we need to build up a feed_dict with
+    # functions computing volatile Python state to be saved with the checkpoint.
+    feed_additions = {}
   for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
     assert node_ids[checkpointable] == checkpoint_id
     object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@
       for saveable in saveables:
         if hasattr(saveable, "full_name"):
           attribute.full_name = saveable.full_name
-        saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
-        if saveable_feed_dict_fn is not None:
-          saveable_feed_dict = saveable_feed_dict_fn()  # pylint: disable=not-callable
-          for new_feed_key in saveable_feed_dict.keys():
-            if new_feed_key in feed_additions:
-              raise AssertionError(
-                  ("The object %s tried to feed a value for the Tensor %s "
-                   "when saving, but another object is already feeding a "
-                   "value.")
-                  % (checkpointable, new_feed_key))
-          feed_additions.update(saveable_feed_dict)
-      named_saveables.extend(saveables)
+        if isinstance(saveable, base.PythonStateSaveable):
+          if feed_additions is None:
+            assert saveables_cache is None
+            # If we're not caching saveables, then we're either executing
+            # eagerly or building a static save/restore (e.g. for a
+            # SavedModel). In either case, we should embed the current Python
+            # state in the graph rather than relying on a feed dict.
+            saveable = saveable.freeze()
+          else:
+            saveable_feed_dict = saveable.feed_dict_additions()
+            for new_feed_key in saveable_feed_dict.keys():
+              if new_feed_key in feed_additions:
+                raise AssertionError(
+                    ("The object %s tried to feed a value for the Tensor %s "
+                     "when saving, but another object is already feeding a "
+                     "value.")
+                    % (checkpointable, new_feed_key))
+            feed_additions.update(saveable_feed_dict)
+        named_saveables.append(saveable)
 
     for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
       child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@
     yield
 
 
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
-  def __init__(self, tensor, name):
-    spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
-    super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
-  def restore(self, restored_tensors, restored_shapes):
-    return control_flow_ops.no_op()
-
-
 class _LoadStatus(object):
   """Abstract base for load status callbacks."""
 
@@ -1241,6 +1244,78 @@
     else:
       return self._root_checkpointable_ref
 
+  def _gather_saveables(
+      self, object_graph_tensor=None, saveable_object_cache=None):
+    """Wraps _serialize_object_graph to include the object graph proto."""
+    assert ((object_graph_tensor is None and saveable_object_cache is None)
+            or (object_graph_tensor is not None
+                and saveable_object_cache is not None))
+    (named_saveable_objects, graph_proto,
+     feed_additions) = _serialize_object_graph(
+         self._root_checkpointable,
+         saveables_cache=saveable_object_cache)
+    if object_graph_tensor is None:
+      with ops.device("/cpu:0"):
+        object_graph_tensor = constant_op.constant(
+            graph_proto.SerializeToString(), dtype=dtypes.string)
+    else:
+      feed_additions.update(
+          {object_graph_tensor: graph_proto.SerializeToString()})
+    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+    named_saveable_objects.append(
+        base.NoRestoreSaveable(
+            tensor=object_graph_tensor,
+            name=base.OBJECT_GRAPH_PROTO_KEY))
+    return named_saveable_objects, graph_proto, feed_additions
+
+  def freeze(self):
+    """Creates a `tf.train.Saver` with the current object graph frozen."""
+    named_saveable_objects, _, _ = self._gather_saveables(
+        object_graph_tensor=None, saveable_object_cache=None)
+    return saver_lib.Saver(
+        var_list=named_saveable_objects, max_to_keep=None)
+
+  def _prepare_save(self,
+                    object_graph_tensor=None,
+                    saveable_object_cache=None):
+    """Create or retrieve save ops.
+
+    When graph building, `saveable_object_cache` will typically be non-`None`,
+    meaning that existing `SaveableObject`s are re-used across calls to
+    `_prepare_save` even if the object graph has grown. This avoids
+    unnecessarily re-creating save ops.
+
+    Args:
+      object_graph_tensor: A `Tensor` to which the current object graph will be
+        fed.
+      saveable_object_cache: A dictionary; if specified, used to cache
+        `SaveableObject`s.
+
+    Returns:
+      A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+      to feed when running save ops. The feed dict contains the current object
+      graph and any Python state to be saved in the checkpoint.
+    """
+    (named_saveable_objects, graph_proto,
+     feed_additions) = self._gather_saveables(
+         object_graph_tensor=object_graph_tensor,
+         saveable_object_cache=saveable_object_cache)
+    if (self._last_save_object_graph != graph_proto
+        # When executing eagerly, we need to re-create SaveableObjects each time
+        # save() is called so they pick up new Tensors passed to their
+        # constructors. That means the Saver needs to be copied with a new
+        # var_list.
+        or context.executing_eagerly()):
+      if self._last_save_object_graph is not None:
+        self._last_save_saver = _copy_saver_with_new_var_list(
+            old_saver=self._last_save_saver,
+            new_var_list=named_saveable_objects)
+      else:
+        self._last_save_saver = saver_lib.Saver(
+            var_list=named_saveable_objects, max_to_keep=None)
+      self._last_save_object_graph = graph_proto
+    return self._last_save_saver, feed_additions
+
   def save(self, file_prefix, checkpoint_number=None, session=None):
     """Save a training checkpoint.
 
@@ -1263,44 +1338,29 @@
     Returns:
       The full path to the checkpoint.
     """
-    named_variables, graph_proto, feed_additions = _serialize_object_graph(
-        self._root_checkpointable,
-        saveables_cache=self._saveable_object_cache)
-    if not context.executing_eagerly():
-      if session is None:
-        session = ops.get_default_session()
+    feed_additions = {}
+    graph_building = not context.executing_eagerly()
+    if graph_building:
       if self._object_graph_feed_tensor is None:
         with ops.device("/cpu:0"):
           self._object_graph_feed_tensor = constant_op.constant(
               "", dtype=dtypes.string)
       object_graph_tensor = self._object_graph_feed_tensor
-      feed_additions.update(
-          {object_graph_tensor: graph_proto.SerializeToString()})
     else:
+      object_graph_tensor = None
+
+    saver, new_feed_additions = self._prepare_save(
+        object_graph_tensor=object_graph_tensor,
+        saveable_object_cache=self._saveable_object_cache)
+    if new_feed_additions:
+      feed_additions.update(new_feed_additions)
+    if not graph_building:
       session = None
-      with ops.device("/cpu:0"):
-        object_graph_tensor = constant_op.constant(
-            graph_proto.SerializeToString(), dtype=dtypes.string)
-    assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
-    named_variables.append(
-        _NoRestoreSaveable(
-            tensor=object_graph_tensor,
-            name=base.OBJECT_GRAPH_PROTO_KEY))
-    if (self._last_save_object_graph != graph_proto
-        # When executing eagerly, we need to re-create SaveableObjects each time
-        # save() is called so they pick up new Tensors passed to their
-        # constructors. That means the Saver needs to be copied with a new
-        # var_list.
-        or context.executing_eagerly()):
-      if self._last_save_object_graph is not None:
-        self._last_save_saver = _copy_saver_with_new_var_list(
-            old_saver=self._last_save_saver, new_var_list=named_variables)
-      else:
-        self._last_save_saver = saver_lib.Saver(
-            var_list=named_variables, max_to_keep=None)
-      self._last_save_object_graph = graph_proto
+    elif session is None:
+      session = ops.get_default_session()
+
     with ops.device("/cpu:0"):
-      save_path = self._last_save_saver.save(
+      save_path = saver.save(
           sess=_SessionWithFeedDictAdditions(
               session=session, feed_additions=feed_additions),
           save_path=file_prefix,
@@ -1422,6 +1482,30 @@
     return load_status
 
 
+def frozen_saver(root_checkpointable):
+  """Creates a static `tf.train.Saver` from a checkpointable object.
+
+  The returned `Saver` saves object-based checkpoints, but these checkpoints
+  will no longer reflect structural changes to the object graph, only changes to
+  the values of `Variable`s added as dependencies of the root object before
+  `freeze` was called.
+
+  `restore` works on the returned `Saver`, but requires that the object graph of
+  the checkpoint being loaded exactly matches the object graph when `freeze` was
+  called. This is in contrast to the object-based restore performed by
+  `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
+  object graph and the current Python object graph.
+
+  Args:
+    root_checkpointable: A checkpointable object to save.
+
+  Returns:
+    A `tf.train.Saver` which saves object-based checkpoints for the object graph
+    frozen at the time `frozen_saver` was called.
+  """
+  return CheckpointableSaver(root_checkpointable).freeze()
+
+
 @tf_export("train.Checkpoint")
 class Checkpoint(tracking.Checkpointable):
   """Groups checkpointable objects, saving and restoring them.
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index bef4bf2..f8b5bd8 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -384,7 +384,7 @@
     saver = saver_lib.Saver(var_list=[v])
     test_dir = self.get_temp_dir()
     prefix = os.path.join(test_dir, "ckpt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.evaluate(v.non_dep_variable.assign(42.))
       save_path = saver.save(sess, prefix)
       self.evaluate(v.non_dep_variable.assign(43.))
@@ -560,6 +560,46 @@
                          self.evaluate(root.save_counter))
 
   @test_util.run_in_graph_and_eager_modes
+  def testFreezing(self):
+    with self.cached_session(use_gpu=True) as session:
+      # Save an object-based checkpoint using a frozen saver
+      directory = self.get_temp_dir()
+      prefix = os.path.join(directory, "ckpt")
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      checkpoint = checkpointable_utils.Checkpoint(v=v)
+      self.evaluate(v.assign(3))
+      # Create the save counter so assert_consumed doesn't complain about it not
+      # existing in the checkpoint on restore.
+      self.evaluate(checkpoint.save_counter.assign(12))
+      saver = checkpointable_utils.frozen_saver(checkpoint)
+      save_path = saver.save(session, prefix)
+      self.evaluate(v.assign(10))
+      # Use the frozen saver to restore the same object graph
+      saver.restore(session, save_path)
+      self.assertEqual(3, self.evaluate(v))
+
+      # Restore using another frozen saver on an identical object graph
+      del v, checkpoint, saver
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      checkpoint = checkpointable_utils.Checkpoint(v=v)
+      saver = checkpointable_utils.frozen_saver(checkpoint)
+      saver.restore(session, save_path)
+      self.assertEqual(3, self.evaluate(v))
+
+      # Restore as an object-based checkpoint
+      del v, checkpoint, saver
+      checkpoint = checkpointable_utils.Checkpoint()
+      status = checkpoint.restore(save_path)
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      if context.executing_eagerly():
+        self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+        self.assertEqual(0, self.evaluate(v))
+      checkpoint.v = v
+      status.assert_consumed().run_restore_ops()
+      self.assertEqual(3, self.evaluate(v))
+      self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
+  @test_util.run_in_graph_and_eager_modes
   def testCustomNumbering(self):
     directory = self.get_temp_dir()
     prefix = os.path.join(directory, "ckpt")
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 76ca5b4..09d6fe3 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -37,7 +37,7 @@
 
   def doTestFtrlwithoutRegularization(self, use_resource=False):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         if use_resource:
           var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
           var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
@@ -76,7 +76,7 @@
 
   def testFtrlwithoutRegularization2(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -105,7 +105,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -121,7 +121,7 @@
 
   def testFtrlWithL1(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -150,7 +150,7 @@
 
   def testFtrlWithL1_L2(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -186,7 +186,7 @@
     weights will tend to have smaller magnitudes with this parameter set.
     """
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([4.0, 3.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
@@ -335,7 +335,7 @@
   # FTRL-Proximal performs same updates as Adagrad or GradientDescent.
   def testEquivAdagradwithoutRegularization(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session():
+      with self.cached_session():
         val0, val1 = self.applyOptimizer(
             ftrl.FtrlOptimizer(
                 3.0,
@@ -346,7 +346,7 @@
                 l2_regularization_strength=0.0),
             dtype)
 
-      with self.test_session():
+      with self.cached_session():
         val2, val3 = self.applyOptimizer(
             adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
 
@@ -355,7 +355,7 @@
 
   def testEquivSparseAdagradwithoutRegularization(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session():
+      with self.cached_session():
         val0, val1 = self.applyOptimizer(
             ftrl.FtrlOptimizer(
                 3.0,
@@ -367,7 +367,7 @@
             dtype,
             is_sparse=True)
 
-      with self.test_session():
+      with self.cached_session():
         val2, val3 = self.applyOptimizer(
             adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
             dtype,
@@ -378,7 +378,7 @@
 
   def testEquivSparseGradientDescentwithoutRegularization(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session():
+      with self.cached_session():
         val0, val1 = self.applyOptimizer(
             ftrl.FtrlOptimizer(
                 3.0,
@@ -390,7 +390,7 @@
             dtype,
             is_sparse=True)
 
-      with self.test_session():
+      with self.cached_session():
         val2, val3 = self.applyOptimizer(
             gradient_descent.GradientDescentOptimizer(3.0),
             dtype,
@@ -401,7 +401,7 @@
 
   def testEquivGradientDescentwithoutRegularization(self):
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session():
+      with self.cached_session():
         val0, val1 = self.applyOptimizer(
             ftrl.FtrlOptimizer(
                 3.0,
@@ -412,7 +412,7 @@
                 l2_regularization_strength=0.0),
             dtype)
 
-      with self.test_session():
+      with self.cached_session():
         val2, val3 = self.applyOptimizer(
             gradient_descent.GradientDescentOptimizer(3.0), dtype)
 
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index b304e92..56d82a5 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -37,7 +37,7 @@
 
   def testBasic(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -60,7 +60,7 @@
 
   def testBasicResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -85,7 +85,7 @@
 
   def testBasicCallableParams(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -111,7 +111,7 @@
 
   def testMinimizeResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -137,7 +137,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -164,7 +164,7 @@
 
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -186,7 +186,7 @@
 
   def testGradWrtRef(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         opt = gradient_descent.GradientDescentOptimizer(3.0)
         values = [1.0, 3.0]
         vars_ = [variables.Variable([v], dtype=dtype) for v in values]
@@ -197,7 +197,7 @@
 
   def testWithGlobalStep(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         global_step = variables.Variable(0, trainable=False)
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
@@ -220,7 +220,7 @@
 
   def testSparseBasic(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
         grads0 = ops.IndexedSlices(
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 0d6207f..9d9db70 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -45,6 +45,7 @@
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.summary import summary
 from tensorflow.python.training import queue_runner
+from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -75,7 +76,10 @@
         collections=[ops.GraphKeys.LOCAL_VARIABLES])
 
 
-@tf_export("train.limit_epochs")
+@tf_export(v1=["train.limit_epochs"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.")
 def limit_epochs(tensor, num_epochs=None, name=None):
   """Returns tensor `num_epochs` times and then raises an `OutOfRange` error.
 
@@ -108,7 +112,12 @@
       return array_ops.identity(tensor, name=name)
 
 
-@tf_export("train.input_producer")
+@tf_export(v1=["train.input_producer"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.from_tensor_slices(input_tensor).shuffle"
+    "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+    "`shuffle=False`, omit the `.shuffle(...)`.")
 def input_producer(input_tensor,
                    element_shape=None,
                    num_epochs=None,
@@ -191,7 +200,12 @@
     return q
 
 
-@tf_export("train.string_input_producer")
+@tf_export(v1=["train.string_input_producer"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.from_tensor_slices(string_tensor).shuffle"
+    "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+    "`shuffle=False`, omit the `.shuffle(...)`.")
 def string_input_producer(string_tensor,
                           num_epochs=None,
                           shuffle=True,
@@ -261,7 +275,11 @@
         cancel_op=cancel_op)
 
 
-@tf_export("train.range_input_producer")
+@tf_export(v1=["train.range_input_producer"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.range(limit).shuffle(limit).repeat(num_epochs)`. If "
+    "`shuffle=False`, omit the `.shuffle(...)`.")
 def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
                          capacity=32, shared_name=None, name=None):
   """Produces the integers from 0 to limit-1 in a queue.
@@ -299,7 +317,12 @@
         shared_name, "fraction_of_%d_full" % capacity, name)
 
 
-@tf_export("train.slice_input_producer")
+@tf_export(v1=["train.slice_input_producer"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.from_tensor_slices(tuple(tensor_list)).shuffle"
+    "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If "
+    "`shuffle=False`, omit the `.shuffle(...)`.")
 def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
                          capacity=32, shared_name=None, name=None):
   """Produces a slice of each `Tensor` in `tensor_list`.
@@ -894,7 +917,11 @@
 # Batching functions ----------------------------------------------------------
 
 
-@tf_export("train.batch")
+@tf_export(v1=["train.batch"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.batch(batch_size)` (or `padded_batch(...)` if "
+    "`dynamic_pad=True`).")
 def batch(tensors, batch_size, num_threads=1, capacity=32,
           enqueue_many=False, shapes=None, dynamic_pad=False,
           allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -989,7 +1016,11 @@
       name=name)
 
 
-@tf_export("train.maybe_batch")
+@tf_export(v1=["train.maybe_batch"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.filter(...).batch(batch_size)` (or `padded_batch(...)`"
+    " if `dynamic_pad=True`).")
 def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
                 enqueue_many=False, shapes=None, dynamic_pad=False,
                 allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1042,7 +1073,11 @@
       name=name)
 
 
-@tf_export("train.batch_join")
+@tf_export(v1=["train.batch_join"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.interleave(...).batch(batch_size)` (or "
+    "`padded_batch(...)` if `dynamic_pad=True`).")
 def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
                shapes=None, dynamic_pad=False, allow_smaller_final_batch=False,
                shared_name=None, name=None):
@@ -1148,7 +1183,11 @@
       name=name)
 
 
-@tf_export("train.maybe_batch_join")
+@tf_export(v1=["train.maybe_batch_join"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.interleave(...).filter(...).batch(batch_size)` (or "
+    "`padded_batch(...)` if `dynamic_pad=True`).")
 def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
                      enqueue_many=False, shapes=None, dynamic_pad=False,
                      allow_smaller_final_batch=False, shared_name=None,
@@ -1201,7 +1240,10 @@
       name=name)
 
 
-@tf_export("train.shuffle_batch")
+@tf_export(v1=["train.shuffle_batch"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.shuffle(min_after_dequeue).batch(batch_size)`.")
 def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                   num_threads=1, seed=None, enqueue_many=False, shapes=None,
                   allow_smaller_final_batch=False, shared_name=None, name=None):
@@ -1301,7 +1343,11 @@
       name=name)
 
 
-@tf_export("train.maybe_shuffle_batch")
+@tf_export(v1=["train.maybe_shuffle_batch"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.filter(...).shuffle(min_after_dequeue).batch(batch_size)`"
+    ".")
 def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                         keep_input, num_threads=1, seed=None,
                         enqueue_many=False, shapes=None,
@@ -1361,7 +1407,11 @@
       name=name)
 
 
-@tf_export("train.shuffle_batch_join")
+@tf_export(v1=["train.shuffle_batch_join"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.interleave(...).shuffle(min_after_dequeue).batch"
+    "(batch_size)`.")
 def shuffle_batch_join(tensors_list, batch_size, capacity,
                        min_after_dequeue, seed=None, enqueue_many=False,
                        shapes=None, allow_smaller_final_batch=False,
@@ -1455,7 +1505,11 @@
       name=name)
 
 
-@tf_export("train.maybe_shuffle_batch_join")
+@tf_export(v1=["train.maybe_shuffle_batch_join"])
+@deprecation.deprecated(
+    None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
+    "`tf.data.Dataset.interleave(...).filter(...).shuffle(min_after_dequeue)"
+    ".batch(batch_size)`.")
 def maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
                              min_after_dequeue, keep_input, seed=None,
                              enqueue_many=False, shapes=None,
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 1b1e89c..a9b05dc 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -51,7 +51,7 @@
     for name in additional:
       open(name, "w").write("Some contents")
     filenames = list(set(filenames + additional))
-    with self.test_session():
+    with self.cached_session():
       star = inp.match_filenames_once(os.path.join(self.get_temp_dir(), "*"))
       question = inp.match_filenames_once(
           os.path.join(self.get_temp_dir(), "match_filenames.?"))
@@ -66,7 +66,7 @@
 class LimitEpochsTest(test_lib.TestCase):
 
   def testNoLimit(self):
-    with self.test_session():
+    with self.cached_session():
       seven = constant_op.constant(7)
       seven_forever = inp.limit_epochs(seven)
       variables.local_variables_initializer().run()
@@ -74,7 +74,7 @@
         self.assertEqual(7, seven_forever.eval())
 
   def testLimit(self):
-    with self.test_session():
+    with self.cached_session():
       love_me = constant_op.constant("Love Me")
       love_me_two_times = inp.limit_epochs(love_me, num_epochs=2)
       variables.global_variables_initializer().run()
@@ -88,7 +88,7 @@
 class InputProducerTest(test_lib.TestCase):
 
   def testNoShuffle(self):
-    with self.test_session():
+    with self.cached_session():
       input_tensor = [[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]]
@@ -111,7 +111,7 @@
         thread.join()
 
   def testNoShapeInference(self):
-    with self.test_session():
+    with self.cached_session():
       # Disable shape inference for the input.
       input_value = [[1, 2, 3, 4],
                      [5, 6, 7, 8],
@@ -144,7 +144,7 @@
 class StringInputProducerTest(test_lib.TestCase):
 
   def testNoShuffle(self):
-    with self.test_session():
+    with self.cached_session():
       strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
       num_epochs = 3
       queue = inp.string_input_producer(
@@ -166,7 +166,7 @@
         thread.join()
 
   def testShuffle(self):
-    with self.test_session():
+    with self.cached_session():
       strings = [b"a", b"b", b"c"]
       num_epochs = 600
       queue = inp.string_input_producer(
@@ -206,7 +206,7 @@
 
   def testNullStringPython(self):
     # Graph-construction time check for empty string list:
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         _ = inp.string_input_producer([])
 
@@ -214,7 +214,7 @@
     # Runtime check for empty string list.  This is slightly oblique:
     # The queue runner should die with an assertion error on the null
     # input tensor, causing the dequeue to fail with an OutOfRangeError.
-    with self.test_session():
+    with self.cached_session():
       coord = coordinator.Coordinator()
       queue = inp.string_input_producer(
           constant_op.constant(
@@ -230,7 +230,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
       queue = inp.string_input_producer(
           strings, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -238,7 +238,7 @@
                              queue.queue_ref.op.node_def.attr["shared_name"])
 
   def testConstructionRace(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       strings = [b"to", b"be", b"or", b"not", b"to", b"be"]
       queue = inp.string_input_producer(strings, shuffle=False)
       coord = coordinator.Coordinator()
@@ -260,7 +260,7 @@
 class RangeInputProducerTest(test_lib.TestCase):
 
   def testNoShuffle(self):
-    with self.test_session():
+    with self.cached_session():
       num_epochs = 3
       range_size = 5
       queue = inp.range_input_producer(
@@ -282,7 +282,7 @@
         thread.join()
 
   def testShuffle(self):
-    with self.test_session():
+    with self.cached_session():
       num_epochs = 200
       range_size = 2
       queue = inp.range_input_producer(
@@ -321,7 +321,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       range_size = 5
       queue = inp.range_input_producer(
           range_size, shared_name="SHARED_NAME_XYZ", name="Q")
@@ -332,7 +332,7 @@
 class SliceInputProducerTest(test_lib.TestCase):
 
   def testNoShuffle(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_epochs = 3
       source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"]
       source_ints = [2, 3, 5, 7]
@@ -356,7 +356,7 @@
         thread.join()
 
   def testShuffle(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       num_epochs = 1200
       source_strings = ["A", "B", "D", "G"]
       source_ints = [7, 3, 5, 2]
@@ -400,7 +400,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       source_strings = ["A", "B", "D", "G"]
       source_ints = [7, 3, 5, 2]
       slices = inp.slice_input_producer(
@@ -440,7 +440,7 @@
 class BatchTest(test_lib.TestCase):
 
   def _testOneThreadHelper(self, use_dict):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -500,7 +500,7 @@
   def testUint32DataTypes(self):
     values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint32)
     batched = inp.batch([values], batch_size=2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
       sess.run(batched)
@@ -511,7 +511,7 @@
   def testUint64DataTypes(self):
     values = constant_op.constant([0, 1, 2, 3, 4, 5], dtype=dtypes.uint64)
     batched = inp.batch([values], batch_size=2)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
       sess.run(batched)
@@ -520,7 +520,7 @@
         thread.join()
 
   def testOneThreadDynamicPad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -550,7 +550,7 @@
         thread.join()
 
   def testOneThreadEnqueueMany(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -585,7 +585,7 @@
         thread.join()
 
   def testManyThreads(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -625,7 +625,7 @@
         thread.join()
 
   def testOneThreadSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       extra_elements = 5
@@ -682,7 +682,7 @@
         thread.join()
 
   def testManyThreadsSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       extra_elements = 5
@@ -737,7 +737,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -754,7 +754,7 @@
           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
 
   def testCannotInferRankError(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtype=dtypes.int64)
       with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
         inp.batch([x], batch_size=2)
@@ -797,7 +797,7 @@
 
   def _testKeepInputHelper(self, num_threads, enqueue_many,
                            keep_input_vector=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 5
       num_batches = 4
       examples = variables.Variable(0)
@@ -934,7 +934,7 @@
     batched = inp.maybe_batch(
         [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
 
-    with self.test_session():
+    with self.cached_session():
       coord = coordinator.Coordinator()
       threads = queue_runner_impl.start_queue_runners(coord=coord)
 
@@ -952,7 +952,7 @@
 class BatchJoinTest(test_lib.TestCase):
 
   def _testTwoThreadsHelper(self, use_dict):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Two threads, the first generates (0..69, "a").
       num_a = 70
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1069,7 +1069,7 @@
           batch_size=8)
 
   def DISABLED_testTwoThreadsDynamicPad(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Two threads, the first generates (0..69, ["a"] * 1..70).
       num_a = 70
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1144,7 +1144,7 @@
         thread.join()
 
   def DISABLED_testTwoThreadsSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       extra_elements = 2
       # Two threads, the first generates (0..69, "a").
       num_a = 70 + extra_elements
@@ -1243,7 +1243,7 @@
         thread.join()
 
   def DISABLED_testTwoThreadsDynamicPadSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       extra_elements = 2
       # Two threads, the first generates (0..69, ["a"] * 1..70).
       num_a = 70 + extra_elements
@@ -1338,7 +1338,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1360,7 +1360,7 @@
           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
 
   def testCannotInferRankError(self):
-    with self.test_session():
+    with self.cached_session():
       x = array_ops.placeholder(dtype=dtypes.int64)
       with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
         inp.batch_join([[x]], batch_size=2)
@@ -1371,7 +1371,7 @@
 
   def _testKeepInputHelper(self, num_threads, enqueue_many,
                            keep_input_vector=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 5
       num_batches = 4
       examples = variables.Variable(0)
@@ -1511,7 +1511,7 @@
     batched = inp.maybe_batch_join(
         [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
 
-    with self.test_session():
+    with self.cached_session():
       coord = coordinator.Coordinator()
       threads = queue_runner_impl.start_queue_runners(coord=coord)
 
@@ -1529,7 +1529,7 @@
 class ShuffleBatchTest(test_lib.TestCase):
 
   def _testOneThreadHelper(self, use_dict):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1594,7 +1594,7 @@
     self._testOneThreadHelper(use_dict=True)
 
   def testOneThreadSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       extra_elements = 5
@@ -1650,7 +1650,7 @@
         thread.join()
 
   def testManyThreads(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1697,7 +1697,7 @@
         thread.join()
 
   def testManyThreadsSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 10
       num_batches = 3
       extra_elements = 5
@@ -1755,7 +1755,7 @@
         thread.join()
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -1775,7 +1775,7 @@
 
   def _testKeepInputHelper(self, num_threads, enqueue_many,
                            keep_input_vector=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 5
       num_batches = 4
       examples = variables.Variable(0)
@@ -1906,7 +1906,7 @@
 class ShuffleBatchJoinTest(test_lib.TestCase):
 
   def _testTwoThreadsHelper(self, use_dict):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Two threads, the first generates (0..24, "a").
       num_a = 25
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2017,7 +2017,7 @@
     self._testTwoThreadsHelper(use_dict=True)
 
   def testTwoThreadsSmallerBatch(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Two threads, the first generates (0..26, "a").
       extra_elements = 2
       num_a = 25 + extra_elements
@@ -2137,7 +2137,7 @@
           seed=223607)
 
   def testSharedName(self):
-    with self.test_session():
+    with self.cached_session():
       batch_size = 10
       num_batches = 3
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
@@ -2162,7 +2162,7 @@
 
   def _testKeepInputHelper(self, num_threads, enqueue_many,
                            keep_input_vector=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       batch_size = 5
       num_batches = 4
       examples = variables.Variable(0)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index fd195a7..29b5465 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -17,19 +17,12 @@
 from __future__ import division
 from __future__ import print_function
 
-import math
-
 from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.training import learning_rate_decay_v2
 from tensorflow.python.util.tf_export import tf_export
 
 
-@tf_export("train.exponential_decay")
+@tf_export(v1=["train.exponential_decay"])
 def exponential_decay(learning_rate,
                       global_step,
                       decay_steps,
@@ -95,32 +88,19 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("global_step is required for exponential_decay.")
-  with ops.name_scope(
-      name, "ExponentialDecay",
-      [learning_rate, global_step, decay_steps, decay_rate]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
-    decay_rate = math_ops.cast(decay_rate, dtype)
+  decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate,
+                                                        global_step,
+                                                        decay_steps,
+                                                        decay_rate,
+                                                        staircase=staircase,
+                                                        name=name)
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      p = global_step_recomp / decay_steps
-      if staircase:
-        p = math_ops.floor(p)
-      return math_ops.multiply(
-          learning_rate, math_ops.pow(decay_rate, p), name=name)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.piecewise_constant")
+@tf_export(v1=["train.piecewise_constant"])
 def piecewise_constant(x, boundaries, values, name=None):
   """Piecewise constant from boundaries and interval values.
 
@@ -163,58 +143,15 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if len(boundaries) != len(values) - 1:
-    raise ValueError(
-        "The length of boundaries should be 1 less than the length of values")
-  with ops.name_scope(name, "PiecewiseConstant",
-                      [x, boundaries, values, name]) as name:
-    boundaries = ops.convert_n_to_tensor(boundaries)
-    values = ops.convert_n_to_tensor(values)
+  decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values,
+                                                         name=name)
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      x_recomp = ops.convert_to_tensor(x)
-      # Avoid explicit conversion to x's dtype. This could result in faulty
-      # comparisons, for example if floats are converted to integers.
-      for i, b in enumerate(boundaries):
-        if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
-          # We can promote int32 boundaries to int64 without loss of precision.
-          # This covers the most common case where the user passes in boundaries
-          # as an array of Python integers.
-          if (b.dtype.base_dtype == dtypes.int32 and
-              x_recomp.dtype.base_dtype == dtypes.int64):
-            b = math_ops.cast(b, x_recomp.dtype.base_dtype)
-            boundaries[i] = b
-          else:
-            raise ValueError(
-                "Boundaries (%s) must have the same dtype as x (%s)." %
-                (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
-      # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
-      for v in values[1:]:
-        if v.dtype.base_dtype != values[0].dtype.base_dtype:
-          raise ValueError(
-              "Values must have elements all with the same dtype (%s vs %s)." %
-              (values[0].dtype.base_dtype, v.dtype.base_dtype))
-      pred_fn_pairs = []
-      pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
-      pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
-      for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
-        # Need to bind v here; can do this with lambda v=v: ...
-        pred = (x_recomp > low) & (x_recomp <= high)
-        pred_fn_pairs.append((pred, lambda v=v: v))
-
-      # The default isn't needed here because our conditions are mutually
-      # exclusive and exhaustive, but tf.case requires it.
-      default = lambda: values[0]
-      return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.polynomial_decay")
+@tf_export(v1=["train.polynomial_decay"])
 def polynomial_decay(learning_rate,
                      global_step,
                      decay_steps,
@@ -299,46 +236,22 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("global_step is required for polynomial_decay.")
-  with ops.name_scope(
-      name, "PolynomialDecay",
-      [learning_rate, global_step, decay_steps, end_learning_rate, power
-      ]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    end_learning_rate = math_ops.cast(end_learning_rate, dtype)
-    power = math_ops.cast(power, dtype)
+  decayed_lr = learning_rate_decay_v2.polynomial_decay(
+      learning_rate,
+      global_step,
+      decay_steps,
+      end_learning_rate=end_learning_rate,
+      power=power,
+      cycle=cycle,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      decay_steps_recomp = math_ops.cast(decay_steps, dtype)
-      if cycle:
-        # Find the first multiple of decay_steps that is bigger than
-        # global_step. If global_step is zero set the multiplier to 1
-        multiplier = control_flow_ops.cond(
-            math_ops.equal(global_step_recomp, 0), lambda: 1.0,
-            lambda: math_ops.ceil(global_step_recomp / decay_steps))
-        decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
-      else:
-        # Make sure that the global_step used is not bigger than decay_steps.
-        global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-      p = math_ops.div(global_step_recomp, decay_steps_recomp)
-      return math_ops.add(
-          math_ops.multiply(learning_rate - end_learning_rate,
-                            math_ops.pow(1 - p, power)),
-          end_learning_rate,
-          name=name)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.natural_exp_decay")
+@tf_export(v1=["train.natural_exp_decay"])
 def natural_exp_decay(learning_rate,
                       global_step,
                       decay_steps,
@@ -410,32 +323,17 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("global_step is required for natural_exp_decay.")
-  with ops.name_scope(name, "NaturalExpDecay",
-                      [learning_rate, global_step, decay_rate]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
-    decay_rate = math_ops.cast(decay_rate, dtype)
+  decayed_lr = learning_rate_decay_v2.natural_exp_decay(
+      learning_rate, global_step, decay_steps, decay_rate, staircase=staircase,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      p = global_step_recomp / decay_steps
-      if staircase:
-        p = math_ops.floor(p)
-      exponent = math_ops.exp(
-          math_ops.multiply(math_ops.negative(decay_rate), p))
-      return math_ops.multiply(learning_rate, exponent, name=name)
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.inverse_time_decay")
+@tf_export(v1=["train.inverse_time_decay"])
 def inverse_time_decay(learning_rate,
                        global_step,
                        decay_steps,
@@ -507,32 +405,21 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("global_step is required for inverse_time_decay.")
-  with ops.name_scope(name, "InverseTimeDecay",
-                      [learning_rate, global_step, decay_rate]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
-    decay_rate = math_ops.cast(decay_rate, dtype)
+  decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+      learning_rate,
+      global_step,
+      decay_steps,
+      decay_rate,
+      staircase=staircase,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      p = global_step_recomp / decay_steps
-      if staircase:
-        p = math_ops.floor(p)
-      const = math_ops.cast(constant_op.constant(1), dtype)
-      denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
-      return math_ops.div(learning_rate, denom, name=name)
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.cosine_decay")
+@tf_export(v1=["train.cosine_decay"])
 def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
   """Applies cosine decay to the learning rate.
 
@@ -581,32 +468,16 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("cosine decay requires global_step")
-  with ops.name_scope(name, "CosineDecay",
-                      [learning_rate, global_step]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
+  decayed_lr = learning_rate_decay_v2.cosine_decay(
+      learning_rate, global_step, decay_steps, alpha=alpha, name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
-      completed_fraction = global_step_recomp / decay_steps
-      cosine_decayed = 0.5 * (1.0 + math_ops.cos(
-          constant_op.constant(math.pi) * completed_fraction))
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-      decayed = (1 - alpha) * cosine_decayed + alpha
-      return math_ops.multiply(learning_rate, decayed)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.cosine_decay_restarts")
+@tf_export(v1=["train.cosine_decay_restarts"])
 def cosine_decay_restarts(learning_rate,
                           global_step,
                           first_decay_steps,
@@ -664,57 +535,22 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("cosine decay restarts requires global_step")
-  with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name:
-    learning_rate = ops.convert_to_tensor(
-        learning_rate, name="initial_learning_rate")
-    dtype = learning_rate.dtype
-    first_decay_steps = math_ops.cast(first_decay_steps, dtype)
-    alpha = math_ops.cast(alpha, dtype)
-    t_mul = math_ops.cast(t_mul, dtype)
-    m_mul = math_ops.cast(m_mul, dtype)
+  decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+      learning_rate,
+      global_step,
+      first_decay_steps,
+      t_mul=t_mul,
+      m_mul=m_mul,
+      alpha=alpha,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      completed_fraction = global_step_recomp / first_decay_steps
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-      def compute_step(completed_fraction, geometric=False):
-        """Helper for `cond` operation."""
-        if geometric:
-          i_restart = math_ops.floor(
-              math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
-              math_ops.log(t_mul))
-
-          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
-          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
-
-        else:
-          i_restart = math_ops.floor(completed_fraction)
-          completed_fraction -= i_restart
-
-        return i_restart, completed_fraction
-
-      i_restart, completed_fraction = control_flow_ops.cond(
-          math_ops.equal(t_mul, 1.0),
-          lambda: compute_step(completed_fraction, geometric=False),
-          lambda: compute_step(completed_fraction, geometric=True))
-
-      m_fac = m_mul**i_restart
-      cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
-          constant_op.constant(math.pi) * completed_fraction))
-      decayed = (1 - alpha) * cosine_decayed + alpha
-
-      return math_ops.multiply(learning_rate, decayed, name=name)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.linear_cosine_decay")
+@tf_export(v1=["train.linear_cosine_decay"])
 def linear_cosine_decay(learning_rate,
                         global_step,
                         decay_steps,
@@ -781,37 +617,22 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("linear cosine decay requires global_step")
-  with ops.name_scope(name, "LinearCosineDecay",
-                      [learning_rate, global_step]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
-    num_periods = math_ops.cast(num_periods, dtype)
-    alpha = math_ops.cast(alpha, dtype)
-    beta = math_ops.cast(beta, dtype)
+  decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+      learning_rate,
+      global_step,
+      decay_steps,
+      num_periods=num_periods,
+      alpha=alpha,
+      beta=beta,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
-      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
-      completed_fraction = global_step_recomp / decay_steps
-      fraction = 2.0 * num_periods * completed_fraction
-      cosine_decayed = 0.5 * (
-          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-      linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
-      return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
 
 
-@tf_export("train.noisy_linear_cosine_decay")
+@tf_export(v1=["train.noisy_linear_cosine_decay"])
 def noisy_linear_cosine_decay(learning_rate,
                               global_step,
                               decay_steps,
@@ -886,42 +707,17 @@
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  if global_step is None:
-    raise ValueError("noisy linear cosine decay requires global_step")
-  with ops.name_scope(name, "NoisyLinearCosineDecay",
-                      [learning_rate, global_step]) as name:
-    learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
-    dtype = learning_rate.dtype
-    decay_steps = math_ops.cast(decay_steps, dtype)
-    initial_variance = math_ops.cast(initial_variance, dtype)
-    variance_decay = math_ops.cast(variance_decay, dtype)
-    num_periods = math_ops.cast(num_periods, dtype)
-    alpha = math_ops.cast(alpha, dtype)
-    beta = math_ops.cast(beta, dtype)
+  decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
+      learning_rate, global_step,
+      decay_steps,
+      initial_variance=initial_variance,
+      variance_decay=variance_decay,
+      num_periods=num_periods,
+      alpha=alpha,
+      beta=beta,
+      name=name)
 
-    def decayed_lr():
-      """Helper to recompute learning rate; most helpful in eager-mode."""
-      global_step_recomp = math_ops.cast(global_step, dtype)
-      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
-      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
-      variance = initial_variance / (
-          math_ops.pow(1.0 + global_step_recomp, variance_decay))
-      std = math_ops.sqrt(variance)
-      noisy_linear_decayed = (
-          linear_decayed + random_ops.random_normal(
-              linear_decayed.shape, stddev=std))
+  if not context.executing_eagerly():
+    decayed_lr = decayed_lr()
 
-      completed_fraction = global_step_recomp / decay_steps
-      fraction = 2.0 * num_periods * completed_fraction
-      cosine_decayed = 0.5 * (
-          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
-      noisy_linear_cosine_decayed = (
-          (alpha + noisy_linear_decayed) * cosine_decayed + beta)
-
-      return math_ops.multiply(
-          learning_rate, noisy_linear_cosine_decayed, name=name)
-
-    if not context.executing_eagerly():
-      decayed_lr = decayed_lr()
-
-    return decayed_lr
+  return decayed_lr
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 4f3cf01..5a92157 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -62,7 +62,7 @@
       self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
 
   def testVariables(self):
-    with self.test_session():
+    with self.cached_session():
       step = variables.Variable(1)
       assign_1 = step.assign(1)
       assign_2 = step.assign(2)
diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py
new file mode 100644
index 0000000..9c5e144
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2.py
@@ -0,0 +1,898 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various learning rate decay functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("train.exponential_decay", v1=[])
+def exponential_decay(learning_rate,
+                      global_step,
+                      decay_steps,
+                      decay_rate,
+                      staircase=False,
+                      name=None):
+  """Applies exponential decay to the learning rate.
+
+  When training a model, it is often recommended to lower the learning rate as
+  the training progresses.  This function applies an exponential decay function
+  to a provided initial learning rate.  It requires a `global_step` value to
+  compute the decayed learning rate.  You can just pass a TensorFlow variable
+  that you increment at each training step.
+
+  The function returns a no-arg function that produces the decayed learning
+  rate. This can be useful for changing the learning rate value across
+  different invocations of optimizer functions.
+  It is computed as:
+
+  ```python
+  decayed_learning_rate = learning_rate *
+                          decay_rate ^ (global_step / decay_steps)
+  ```
+
+  If the argument `staircase` is `True`, then `global_step / decay_steps` is an
+  integer division and the decayed learning rate follows a staircase function.
+
+  Example: decay every 100000 steps with a base of 0.96:
+
+  ```python
+  ...
+  global_step = tf.Variable(0, trainable=False)
+  starter_learning_rate = 0.1
+  learning_rate_fn = tf.train.exponential_decay(starter_learning_rate,
+                                                global_step, 100000, 0.96,
+                                                staircase=True)
+  # Passing global_step to minimize() will increment it at each step.
+  learning_step = (
+      tf.train.GradientDescentOptimizer(learning_rate_fn)
+      .minimize(...my loss..., global_step=global_step)
+  )
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Global step to use for the decay computation.  Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Must be positive.  See the decay computation above.
+    decay_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The decay rate.
+    staircase: Boolean.  If `True` decay the learning rate at discrete intervals.
+    name: String.  Optional name of the operation.  Defaults to
+      'ExponentialDecay'.
+
+  Returns:
+    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+    of the same type as `learning_rate`.
+
+  Raises:
+    ValueError: if `global_step` is not supplied.
+  """
+  if global_step is None:
+    raise ValueError("global_step is required for exponential_decay.")
+  def decayed_lr(learning_rate, global_step, decay_steps, decay_rate,
+                 staircase, name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(
+        name, "ExponentialDecay",
+        [learning_rate, global_step, decay_steps, decay_rate]) as name:
+      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+      dtype = learning_rate.dtype
+      decay_steps = math_ops.cast(decay_steps, dtype)
+      decay_rate = math_ops.cast(decay_rate, dtype)
+
+      global_step_recomp = math_ops.cast(global_step, dtype)
+      p = global_step_recomp / decay_steps
+      if staircase:
+        p = math_ops.floor(p)
+      return math_ops.multiply(
+          learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+                           decay_rate, staircase, name)
+
+
+@tf_export("train.piecewise_constant", v1=[])
+def piecewise_constant(x, boundaries, values, name=None):
+  """Piecewise constant from boundaries and interval values.
+
+  This function returns a no-arg callable to compute the piecewise constant.
+  This can be useful for changing the learning rate value across
+  different invocations of optimizer functions.
+
+  Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
+    for the next 10000 steps, and 0.1 for any additional steps.
+
+  ```python
+  global_step = tf.Variable(0, trainable=False)
+  boundaries = [100000, 110000]
+  values = [1.0, 0.5, 0.1]
+  learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries,
+    values)
+  learning_rate = learning_rate_fn()
+
+  # Later, whenever we perform an optimization step, we increment global_step.
+  ```
+
+  Args:
+    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
+      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
+    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
+      increasing entries, and with all elements having the same type as `x`.
+    values: A list of `Tensor`s or `float`s or `int`s that specifies the values
+      for the intervals defined by `boundaries`. It should have one more element
+      than `boundaries`, and all elements should have the same type.
+    name: A string. Optional name of the operation. Defaults to
+      'PiecewiseConstant'.
+
+  Returns:
+    A no-arg function that outputs a 0-D Tensor. The output of the no-arg
+    function is `values[0]` when `x <= boundaries[0]`,
+    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
+    and values[-1] when `x > boundaries[-1]`.
+
+  Raises:
+    ValueError: if types of `x` and `boundaries` do not match, or types of all
+        `values` do not match or
+        the number of elements in the lists does not match.
+  """
+  if len(boundaries) != len(values) - 1:
+    raise ValueError(
+        "The length of boundaries should be 1 less than the length of values")
+  def decayed_lr(x, boundaries, values, name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(name, "PiecewiseConstant",
+                        [x, boundaries, values, name]) as name:
+      boundaries = ops.convert_n_to_tensor(boundaries)
+      values = ops.convert_n_to_tensor(values)
+      x_recomp = ops.convert_to_tensor(x)
+      # Avoid explicit conversion to x's dtype. This could result in faulty
+      # comparisons, for example if floats are converted to integers.
+      for i, b in enumerate(boundaries):
+        if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
+          # We can promote int32 boundaries to int64 without loss of precision.
+          # This covers the most common case where the user passes in boundaries
+          # as an array of Python integers.
+          if (b.dtype.base_dtype == dtypes.int32 and
+              x_recomp.dtype.base_dtype == dtypes.int64):
+            b = math_ops.cast(b, x_recomp.dtype.base_dtype)
+            boundaries[i] = b
+          else:
+            raise ValueError(
+                "Boundaries (%s) must have the same dtype as x (%s)." %
+                (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
+      # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
+      for v in values[1:]:
+        if v.dtype.base_dtype != values[0].dtype.base_dtype:
+          raise ValueError(
+              "Values must have elements all with the same dtype (%s vs %s)." %
+              (values[0].dtype.base_dtype, v.dtype.base_dtype))
+      pred_fn_pairs = []
+      pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
+      pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
+      for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+        # Need to bind v here; can do this with lambda v=v: ...
+        pred = (x_recomp > low) & (x_recomp <= high)
+        pred_fn_pairs.append((pred, lambda v=v: v))
+
+      # The default isn't needed here because our conditions are mutually
+      # exclusive and exhaustive, but tf.case requires it.
+      default = lambda: values[0]
+      return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+  return functools.partial(decayed_lr, x, boundaries, values, name)
+
+
+@tf_export("train.polynomial_decay", v1=[])
+def polynomial_decay(learning_rate,
+                     global_step,
+                     decay_steps,
+                     end_learning_rate=0.0001,
+                     power=1.0,
+                     cycle=False,
+                     name=None):
+  """Applies a polynomial decay to the learning rate.
+
+  It is commonly observed that a monotonically decreasing learning rate, whose
+  degree of change is carefully chosen, results in a better performing model.
+  This function applies a polynomial decay function to a provided initial
+  `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.
+
+  It requires a `global_step` value to compute the decayed learning rate.  You
+  can just pass a TensorFlow variable that you increment at each training step.
+
+  The function returns a no-arg callable that outputs the decayed learning
+  rate. This can be useful for changing the learning rate value across
+  different invocations of optimizer functions. It is computed as:
+
+  ```python
+  global_step = min(global_step, decay_steps)
+  decayed_learning_rate = (learning_rate - end_learning_rate) *
+                          (1 - global_step / decay_steps) ^ (power) +
+                          end_learning_rate
+
+  ```
+
+  If `cycle` is True then a multiple of `decay_steps` is used, the first one
+  that is bigger than `global_step`.
+
+  ```python
+  decay_steps = decay_steps * ceil(global_step / decay_steps)
+  decayed_learning_rate_fn = (learning_rate - end_learning_rate) *
+                          (1 - global_step / decay_steps) ^ (power) +
+                          end_learning_rate
+  decayed_learning_rate = decayed_learning_rate_fn()
+
+  ```
+
+  Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):
+
+  ```python
+  ...
+  global_step = tf.Variable(0, trainable=False)
+  starter_learning_rate = 0.1
+  end_learning_rate = 0.01
+  decay_steps = 10000
+  learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate,
+                                               global_step, decay_steps,
+                                               end_learning_rate,
+                                               power=0.5)
+  # Passing global_step to minimize() will increment it at each step.
+  learning_step = (
+      tf.train.GradientDescentOptimizer(learning_rate_fn)
+      .minimize(...my loss..., global_step=global_step)
+  )
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Global step to use for the decay computation.  Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Must be positive.  See the decay computation above.
+    end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The minimal end learning rate.
+    power: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The power of the polynomial. Defaults to linear, 1.0.
+    cycle: A boolean, whether or not it should cycle beyond decay_steps.
+    name: String.  Optional name of the operation. Defaults to
+      'PolynomialDecay'.
+
+  Returns:
+    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+    of the same type as `learning_rate`.
+
+  Raises:
+    ValueError: if `global_step` is not supplied.
+  """
+  if global_step is None:
+    raise ValueError("global_step is required for polynomial_decay.")
+  def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate,
+                 power, cycle, name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(
+        name, "PolynomialDecay",
+        [learning_rate, global_step, decay_steps, end_learning_rate, power]
+    ) as name:
+      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+      dtype = learning_rate.dtype
+      end_learning_rate = math_ops.cast(end_learning_rate, dtype)
+      power = math_ops.cast(power, dtype)
+
+      global_step_recomp = math_ops.cast(global_step, dtype)
+      decay_steps_recomp = math_ops.cast(decay_steps, dtype)
+      if cycle:
+        # Find the first multiple of decay_steps that is bigger than
+        # global_step. If global_step is zero set the multiplier to 1
+        multiplier = control_flow_ops.cond(
+            math_ops.equal(global_step_recomp, 0), lambda: 1.0,
+            lambda: math_ops.ceil(global_step_recomp / decay_steps))
+        decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
+      else:
+        # Make sure that the global_step used is not bigger than decay_steps.
+        global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+
+      p = math_ops.div(global_step_recomp, decay_steps_recomp)
+      return math_ops.add(
+          math_ops.multiply(learning_rate - end_learning_rate,
+                            math_ops.pow(1 - p, power)),
+          end_learning_rate,
+          name=name)
+
+  return functools.partial(
+      decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate,
+      power, cycle, name)
+
+
+@tf_export("train.natural_exp_decay", v1=[])
+def natural_exp_decay(learning_rate,
+                      global_step,
+                      decay_steps,
+                      decay_rate,
+                      staircase=False,
+                      name=None):
+  """Applies natural exponential decay to the initial learning rate.
+
+  When training a model, it is often recommended to lower the learning rate as
+  the training progresses.  This function applies an exponential decay function
+  to a provided initial learning rate.  It requires a `global_step` value to
+  compute the decayed learning rate.  You can just pass a TensorFlow variable
+  that you increment at each training step.
+
+  The function returns a no-arg callable that produces the decayed learning
+  rate. This can be useful for changing the learning rate value across
+  different invocations of optimizer functions. It is computed as:
+
+  ```python
+  decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
+  decay_step)
+  ```
+
+  or, if `staircase` is `True`, as:
+
+  ```python
+  decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
+  decay_step))
+  ```
+
+  Example: decay exponentially with a base of 0.96:
+
+  ```python
+  ...
+  global_step = tf.Variable(0, trainable=False)
+  learning_rate = 0.1
+  decay_steps = 5
+  k = 0.5
+  learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step,
+                                                decay_steps, k)
+
+  # Passing global_step to minimize() will increment it at each step.
+  learning_step = (
+      tf.train.GradientDescentOptimizer(learning_rate_fn)
+      .minimize(...my loss..., global_step=global_step)
+  )
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The initial learning rate.
+    global_step: A Python number.
+      Global step to use for the decay computation.  Must not be negative.
+    decay_steps: How often to apply decay.
+    decay_rate: A Python number.  The decay rate.
+    staircase: Whether to apply decay in a discrete staircase, as opposed to
+      continuous, fashion.
+    name: String.  Optional name of the operation.  Defaults to
+      'ExponentialTimeDecay'.
+
+  Returns:
+    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+    of the same type as `learning_rate`.
+
+  Raises:
+    ValueError: if `global_step` is not supplied.
+  """
+  if global_step is None:
+    raise ValueError("global_step is required for natural_exp_decay.")
+  def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+                 name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(name, "NaturalExpDecay",
+                        [learning_rate, global_step, decay_rate]) as name:
+      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+      dtype = learning_rate.dtype
+      decay_steps = math_ops.cast(decay_steps, dtype)
+      decay_rate = math_ops.cast(decay_rate, dtype)
+
+      global_step_recomp = math_ops.cast(global_step, dtype)
+      p = global_step_recomp / decay_steps
+      if staircase:
+        p = math_ops.floor(p)
+      exponent = math_ops.exp(
+          math_ops.multiply(math_ops.negative(decay_rate), p))
+      return math_ops.multiply(learning_rate, exponent, name=name)
+
+  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+                           decay_rate, staircase, name)
+
+
+@tf_export("train.inverse_time_decay", v1=[])
+def inverse_time_decay(learning_rate,
+                       global_step,
+                       decay_steps,
+                       decay_rate,
+                       staircase=False,
+                       name=None):
+  """Applies inverse time decay to the initial learning rate.
+
+  When training a model, it is often recommended to lower the learning rate as
+  the training progresses.  This function applies an inverse decay function
+  to a provided initial learning rate.  It requires a `global_step` value to
+  compute the decayed learning rate.  You can just pass a TensorFlow variable
+  that you increment at each training step.
+
+  The function returns a no-arg callable that produces the decayed learning
+  rate. This can be useful for changing the learning rate value across
+  different invocations of optimizer functions. It is computed as:
+
+  ```python
+  decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
+  decay_step)
+  ```
+
+  or, if `staircase` is `True`, as:
+
+  ```python
+  decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
+  decay_step))
+  ```
+
+  Example: decay 1/t with a rate of 0.5:
+
+  ```python
+  ...
+  global_step = tf.Variable(0, trainable=False)
+  learning_rate = 0.1
+  decay_steps = 1.0
+  decay_rate = 0.5
+  learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step,
+  decay_steps, decay_rate)
+
+  # Passing global_step to minimize() will increment it at each step.
+  learning_step = (
+      tf.train.GradientDescentOptimizer(learning_rate_fn)
+      .minimize(...my loss..., global_step=global_step)
+  )
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The initial learning rate.
+    global_step: A Python number.
+      Global step to use for the decay computation.  Must not be negative.
+    decay_steps: How often to apply decay.
+    decay_rate: A Python number.  The decay rate.
+    staircase: Whether to apply decay in a discrete staircase, as opposed to
+      continuous, fashion.
+    name: String.  Optional name of the operation.  Defaults to
+      'InverseTimeDecay'.
+
+  Returns:
+    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+    of the same type as `learning_rate`.
+
+  Raises:
+    ValueError: if `global_step` is not supplied.
+  """
+  if global_step is None:
+    raise ValueError("global_step is required for inverse_time_decay.")
+  def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
+                 name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(name, "InverseTimeDecay",
+                        [learning_rate, global_step, decay_rate]) as name:
+      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+      dtype = learning_rate.dtype
+      decay_steps = math_ops.cast(decay_steps, dtype)
+      decay_rate = math_ops.cast(decay_rate, dtype)
+
+      global_step_recomp = math_ops.cast(global_step, dtype)
+      p = global_step_recomp / decay_steps
+      if staircase:
+        p = math_ops.floor(p)
+      const = math_ops.cast(constant_op.constant(1), dtype)
+      denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
+      return math_ops.div(learning_rate, denom, name=name)
+
+  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+                           decay_rate, staircase, name)
+
+
+@tf_export("train.cosine_decay", v1=[])
+def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
+                 name=None):
+  """Applies cosine decay to the learning rate.
+
+  See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
+  with Warm Restarts. https://arxiv.org/abs/1608.03983
+
+  When training a model, it is often recommended to lower the learning rate as
+  the training progresses.  This function applies a cosine decay function
+  to a provided initial learning rate.  It requires a `global_step` value to
+  compute the decayed learning rate.  You can just pass a TensorFlow variable
+  that you increment at each training step.
+
+  The function returns a no-arg callable that produces the decayed learning
+  rate. This can be useful for changing the learning rate value across
+  different invocations of optimizer functions. It is computed as:
+
+  ```python
+  global_step = min(global_step, decay_steps)
+  cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
+  decayed = (1 - alpha) * cosine_decay + alpha
+  decayed_learning_rate = learning_rate * decayed
+  ```
+
+  Example usage:
+  ```python
+  decay_steps = 1000
+  lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps)
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
+      The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Global step to use for the decay computation.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Number of steps to decay over.
+    alpha: A scalar `float32` or `float64` Tensor or a Python number.
+      Minimum learning rate value as a fraction of learning_rate.
+    name: String. Optional name of the operation.  Defaults to 'CosineDecay'.
+  Returns:
+    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
+    of the same type as `learning_rate`.
+  Raises:
+    ValueError: if `global_step` is not supplied.
+  """
+  if global_step is None:
+    raise ValueError("cosine decay requires global_step")
+  def decayed_lr(learning_rate, global_step, decay_steps, alpha, name):
+    """Helper to recompute learning rate; most helpful in eager-mode."""
+    with ops.name_scope(name, "CosineDecay",
+                        [learning_rate, global_step]) as name:
+      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+      dtype = learning_rate.dtype
+      decay_steps = math_ops.cast(decay_steps, dtype)
+
+      global_step_recomp = math_ops.cast(global_step, dtype)
+      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+      completed_fraction = global_step_recomp / decay_steps
+      cosine_decayed = 0.5 * (1.0 + math_ops.cos(
+          constant_op.constant(math.pi) * completed_fraction))
+
+      decayed = (1 - alpha) * cosine_decayed + alpha
+      return math_ops.multiply(learning_rate, decayed)
+
+  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
+                           alpha, name)
+
+
@tf_export("train.cosine_decay_restarts", v1=[])
def cosine_decay_restarts(learning_rate,
                          global_step,
                          first_decay_steps,
                          t_mul=2.0,
                          m_mul=1.0,
                          alpha=0.0,
                          name=None):
  """Applies cosine decay with restarts to the learning rate.

  See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  When training a model, it is often recommended to lower the learning rate as
  the training progresses.  This function applies a cosine decay function with
  restarts to a provided initial learning rate.  It requires a `global_step`
  value to compute the decayed learning rate.  You can just pass a TensorFlow
  variable that you increment at each training step.

  The function returns a no-arg callable that produces the decayed learning
  rate while taking into account possible warm restarts. This can be useful for
  changing the learning rate value across different invocations of optimizer
  functions.

  The learning rate multiplier first decays
  from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
  restart is performed. Each new warm restart runs for `t_mul` times more steps
  and with `m_mul` times smaller initial learning rate.

  Example usage:
  ```python
  first_decay_steps = 1000
  lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step,
                                     first_decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Used to derive the number of iterations in the i-th period.
    m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
      Used to derive the initial learning rate of the i-th period.
    alpha: A scalar `float32` or `float64` Tensor or a Python number.
      Minimum learning rate value as a fraction of the learning_rate.
    name: String. Optional name of the operation.  Defaults to 'SGDRDecay'.
  Returns:
    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
    of the same type as `learning_rate`.

  Raises:
    ValueError: if `global_step` is not supplied.
  """
  if global_step is None:
    raise ValueError("cosine decay restarts requires global_step")
  def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul,
                 alpha, name):
    """Helper to recompute learning rate; most helpful in eager-mode."""
    with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]
                       ) as name:
      learning_rate = ops.convert_to_tensor(
          learning_rate, name="initial_learning_rate")
      dtype = learning_rate.dtype
      # Cast every scalar hyperparameter to the learning rate's dtype so the
      # arithmetic below runs in a single precision.
      first_decay_steps = math_ops.cast(first_decay_steps, dtype)
      alpha = math_ops.cast(alpha, dtype)
      t_mul = math_ops.cast(t_mul, dtype)
      m_mul = math_ops.cast(m_mul, dtype)

      global_step_recomp = math_ops.cast(global_step, dtype)
      # Fraction of the *first* decay period that has elapsed; may exceed 1.0
      # once one or more warm restarts have occurred.
      completed_fraction = global_step_recomp / first_decay_steps

      def compute_step(completed_fraction, geometric=False):
        """Maps `completed_fraction` to (restart index, fraction in period)."""
        if geometric:
          # Period lengths grow geometrically by `t_mul`; invert the
          # geometric-series sum to find which restart period we are in.
          i_restart = math_ops.floor(
              math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
              math_ops.log(t_mul))

          # Subtract the total length of all completed periods, then rescale
          # to a [0, 1) fraction of the current period.
          sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
          completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart

        else:
          # Constant-length periods: the integer part is the restart index.
          i_restart = math_ops.floor(completed_fraction)
          completed_fraction -= i_restart

        return i_restart, completed_fraction

      # The geometric branch divides by log(t_mul) and (1 - t_mul), both zero
      # when t_mul == 1.0, so that case must take the constant-period branch.
      i_restart, completed_fraction = control_flow_ops.cond(
          math_ops.equal(t_mul, 1.0),
          lambda: compute_step(completed_fraction, geometric=False),
          lambda: compute_step(completed_fraction, geometric=True))

      # Each completed restart scales the initial rate by another m_mul.
      m_fac = m_mul**i_restart
      cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
          constant_op.constant(math.pi) * completed_fraction))
      decayed = (1 - alpha) * cosine_decayed + alpha

      return math_ops.multiply(learning_rate, decayed, name=name)

  # functools.partial freezes the arguments so callers get a no-arg callable
  # that recomputes the rate on every invocation (useful in eager mode).
  return functools.partial(decayed_lr, learning_rate, global_step,
                           first_decay_steps, t_mul, m_mul, alpha, name)
+
+
@tf_export("train.linear_cosine_decay", v1=[])
def linear_cosine_decay(learning_rate,
                        global_step,
                        decay_steps,
                        num_periods=0.5,
                        alpha=0.0,
                        beta=0.001,
                        name=None):
  """Applies linear cosine decay to the learning rate.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses.  This function applies a linear cosine decay function
  to a provided initial learning rate.  It requires a `global_step` value to
  compute the decayed learning rate.  You can just pass a TensorFlow variable
  that you increment at each training step.

  The function returns a no-arg callable that produces the decayed learning
  rate. This can be useful for changing the learning rate value across
  different invocations of optimizer functions. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```

  Example usage:
  ```python
  decay_steps = 1000
  lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step,
                                               decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    num_periods: Number of periods in the cosine part of the decay.
      See computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String.  Optional name of the operation.  Defaults to
      'LinearCosineDecay'.
  Returns:
    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
    of the same type as `learning_rate`.
  Raises:
    ValueError: if `global_step` is not supplied.
  """
  if global_step is None:
    raise ValueError("linear cosine decay requires global_step")
  def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha,
                 beta, name):
    """Helper to recompute learning rate; most helpful in eager-mode."""
    with ops.name_scope(name, "LinearCosineDecay",
                        [learning_rate, global_step]) as name:
      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
      dtype = learning_rate.dtype
      # Cast the scalar hyperparameters to the learning rate's dtype so the
      # arithmetic below runs in a single precision.
      decay_steps = math_ops.cast(decay_steps, dtype)
      num_periods = math_ops.cast(num_periods, dtype)
      alpha = math_ops.cast(alpha, dtype)
      beta = math_ops.cast(beta, dtype)

      global_step_recomp = math_ops.cast(global_step, dtype)
      # Clamp to the end of the schedule so the rate holds its final value.
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
      completed_fraction = global_step_recomp / decay_steps
      # `fraction` sweeps 2 * num_periods half-cosine periods over the decay.
      fraction = 2.0 * num_periods * completed_fraction
      cosine_decayed = 0.5 * (
          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))

      # `beta` keeps the multiplier strictly positive at the cosine zeros.
      linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
      return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)

  # functools.partial freezes the arguments so callers get a no-arg callable.
  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
                           num_periods, alpha, beta, name)
+
+
@tf_export("train.noisy_linear_cosine_decay", v1=[])
def noisy_linear_cosine_decay(learning_rate,
                              global_step,
                              decay_steps,
                              initial_variance=1.0,
                              variance_decay=0.55,
                              num_periods=0.5,
                              alpha=0.0,
                              beta=0.001,
                              name=None):
  """Applies noisy linear cosine decay to the learning rate.

  See [Bello et al., ICML2017] Neural Optimizer Search with RL.
  https://arxiv.org/abs/1709.07417

  For the idea of warm starts here controlled by `num_periods`,
  see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
  with Warm Restarts. https://arxiv.org/abs/1608.03983

  Note that linear cosine decay is more aggressive than cosine decay and
  larger initial learning rates can typically be used.

  When training a model, it is often recommended to lower the learning rate as
  the training progresses.  This function applies a noisy linear
  cosine decay function to a provided initial learning rate.
  It requires a `global_step` value to compute the decayed learning rate.
  You can just pass a TensorFlow variable that you increment at each
  training step.

  The function returns a no-arg callable that produces the decayed learning
  rate. This can be useful for changing the learning rate value across
  different invocations of optimizer functions. It is computed as:

  ```python
  global_step = min(global_step, decay_steps)
  linear_decay = (decay_steps - global_step) / decay_steps
  cosine_decay = 0.5 * (
      1 + cos(pi * 2 * num_periods * global_step / decay_steps))
  decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
  decayed_learning_rate = learning_rate * decayed
  ```
  where eps_t is 0-centered gaussian noise with variance
  initial_variance / (1 + global_step) ** variance_decay

  Example usage:
  ```python
  decay_steps = 1000
  lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step,
                                                     decay_steps)
  ```

  Args:
    learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
      The initial learning rate.
    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
      Global step to use for the decay computation.
    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
      Number of steps to decay over.
    initial_variance: initial variance for the noise. See computation above.
    variance_decay: decay for the noise's variance. See computation above.
    num_periods: Number of periods in the cosine part of the decay.
      See computation above.
    alpha: See computation above.
    beta: See computation above.
    name: String.  Optional name of the operation.  Defaults to
      'NoisyLinearCosineDecay'.
  Returns:
    A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
    of the same type as `learning_rate`.
  Raises:
    ValueError: if `global_step` is not supplied.
  """
  if global_step is None:
    raise ValueError("noisy linear cosine decay requires global_step")
  def decayed_lr(learning_rate, global_step, decay_steps, initial_variance,
                 variance_decay, num_periods, alpha, beta, name):
    """Helper to recompute learning rate; most helpful in eager-mode."""
    with ops.name_scope(name, "NoisyLinearCosineDecay",
                        [learning_rate, global_step]) as name:
      learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
      dtype = learning_rate.dtype
      # Cast the scalar hyperparameters to the learning rate's dtype so the
      # arithmetic below runs in a single precision.
      decay_steps = math_ops.cast(decay_steps, dtype)
      initial_variance = math_ops.cast(initial_variance, dtype)
      variance_decay = math_ops.cast(variance_decay, dtype)
      num_periods = math_ops.cast(num_periods, dtype)
      alpha = math_ops.cast(alpha, dtype)
      beta = math_ops.cast(beta, dtype)

      global_step_recomp = math_ops.cast(global_step, dtype)
      # Clamp to the end of the schedule so the rate holds its final value.
      global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
      linear_decayed = (decay_steps - global_step_recomp) / decay_steps
      # Noise variance shrinks as training progresses:
      # initial_variance / (1 + step) ** variance_decay.
      variance = initial_variance / (
          math_ops.pow(1.0 + global_step_recomp, variance_decay))
      std = math_ops.sqrt(variance)
      # The gaussian perturbation matches linear_decayed's shape (a scalar for
      # a scalar learning rate). No seed is set, so this is nondeterministic.
      noisy_linear_decayed = (
          linear_decayed + random_ops.random_normal(
              linear_decayed.shape, stddev=std))

      completed_fraction = global_step_recomp / decay_steps
      # `fraction` sweeps 2 * num_periods half-cosine periods over the decay.
      fraction = 2.0 * num_periods * completed_fraction
      cosine_decayed = 0.5 * (
          1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
      noisy_linear_cosine_decayed = (
          (alpha + noisy_linear_decayed) * cosine_decayed + beta)

      return math_ops.multiply(
          learning_rate, noisy_linear_cosine_decayed, name=name)

  # functools.partial freezes the arguments so callers get a no-arg callable.
  return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
                           initial_variance, variance_decay, num_periods, alpha,
                           beta, name)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
new file mode 100644
index 0000000..0f2d60d
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -0,0 +1,497 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+# Import resource_variable_ops for the variables-to-tensor implicit conversion.
+from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import learning_rate_decay_v2
+
+
class LRDecayTestV2(test_util.TensorFlowTestCase):
  """Tests for `exponential_decay` and `piecewise_constant` (v2 variants)."""

  @test_util.run_in_graph_and_eager_modes
  def testContinuous(self):
    self.evaluate(variables.global_variables_initializer())
    step = 5
    decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96)
    expected = .05 * 0.96**(5.0 / 10.0)
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testStaircase(self):
    # NOTE(review): the body only runs in eager mode; in graph mode this test
    # is a no-op.  Graph-mode staircase behavior is covered by testVariables.
    if context.executing_eagerly():
      step = resource_variable_ops.ResourceVariable(0)
      self.evaluate(variables.global_variables_initializer())
      decayed_lr = learning_rate_decay_v2.exponential_decay(
          .1, step, 3, 0.96, staircase=True)

      # No change to learning rate due to staircase
      expected = .1
      self.evaluate(step.assign(1))
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

      expected = .1
      self.evaluate(step.assign(2))
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

      # Decayed learning rate
      expected = .1 * 0.96 ** (100 // 3)
      self.evaluate(step.assign(100))
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  def testVariables(self):
    # Graph-mode staircase test. Uses cached_session(): test_session() is
    # deprecated and the rest of this change migrates to cached_session().
    with self.cached_session():
      step = variables.Variable(1)
      assign_1 = step.assign(1)
      assign_2 = step.assign(2)
      assign_100 = step.assign(100)
      decayed_lr = learning_rate_decay_v2.exponential_decay(.1, step, 3, 0.96,
                                                            staircase=True)
      variables.global_variables_initializer().run()
      # No change to learning rate
      assign_1.op.run()
      self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
      assign_2.op.run()
      self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
      # Decayed learning rate
      assign_100.op.run()
      expected = .1 * 0.96 ** (100 // 3)
      self.assertAllClose(decayed_lr().eval(), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testPiecewiseConstant(self):
    x = resource_variable_ops.ResourceVariable(-999)
    decayed_lr = learning_rate_decay_v2.piecewise_constant(
        x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])

    self.evaluate(variables.global_variables_initializer())

    self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
    self.evaluate(x.assign(100))
    self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
    self.evaluate(x.assign(105))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
    self.evaluate(x.assign(110))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
    self.evaluate(x.assign(120))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6)
    self.evaluate(x.assign(999))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testPiecewiseConstantEdgeCases(self):
    # Mismatched boundary/value dtypes must be rejected.
    x_int = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int32)
    boundaries, values = [-1.0, 1.0], [1, 2, 3]
    with self.assertRaises(ValueError):
      decayed_lr = learning_rate_decay_v2.piecewise_constant(
          x_int, boundaries, values)
      decayed_lr()

    x = resource_variable_ops.ResourceVariable(0.0)
    boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
    with self.assertRaises(ValueError):
      decayed_lr = learning_rate_decay_v2.piecewise_constant(
          x, boundaries, values)()
      decayed_lr()

    # Test that ref types are valid.
    if not context.executing_eagerly():
      x = variables.Variable(0.0)
      x_ref = x.op.outputs[0]   # float32_ref tensor should be accepted
      boundaries, values = [1.0, 2.0], [1, 2, 3]
      learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values)

    # Test casting boundaries from int32 to int64.
    x_int64 = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int64)
    boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
    decayed_lr = learning_rate_decay_v2.piecewise_constant(
        x_int64, boundaries, values)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
    self.evaluate(x_int64.assign(1))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
    self.evaluate(x_int64.assign(2))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6)
    self.evaluate(x_int64.assign(3))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6)
    self.evaluate(x_int64.assign(4))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6)
+
+
class LinearDecayTestV2(test_util.TensorFlowTestCase):
  """Checks polynomial_decay (default linear power) against closed forms."""

  def _assert_decayed(self, step, lr, end_lr, expected, **kwargs):
    """Builds a 10-step decay schedule and checks its value at `step`."""
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, **kwargs)
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWay(self):
    # Halfway through, the rate is midway between lr and end_lr (= 0 here).
    self._assert_decayed(step=5, lr=0.05, end_lr=0.0, expected=0.05 * 0.5)

  @test_util.run_in_graph_and_eager_modes
  def testEnd(self):
    # At the final step the rate has fully decayed to end_lr.
    self._assert_decayed(step=10, lr=0.05, end_lr=0.001, expected=0.001)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWayWithEnd(self):
    self._assert_decayed(
        step=5, lr=0.05, end_lr=0.001, expected=(0.05 + 0.001) * 0.5)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEnd(self):
    # Past the schedule's end the rate stays clamped at end_lr.
    self._assert_decayed(step=15, lr=0.05, end_lr=0.001, expected=0.001)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEndWithCycle(self):
    # With cycling, step 15 sits a quarter of the way into the second cycle.
    self._assert_decayed(
        step=15, lr=0.05, end_lr=0.001,
        expected=(0.05 - 0.001) * 0.25 + 0.001, cycle=True)
+
+
class SqrtDecayTestV2(test_util.TensorFlowTestCase):
  """Checks polynomial_decay with power=0.5 (square-root decay)."""

  def _assert_decayed(self, step, lr, end_lr, expected, **kwargs):
    """Builds a 10-step sqrt-decay schedule and checks its value at `step`."""
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=0.5, **kwargs)
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWay(self):
    # Halfway through, the multiplier is sqrt(0.5) with end_lr == 0.
    self._assert_decayed(step=5, lr=0.05, end_lr=0.0, expected=0.05 * 0.5**0.5)

  @test_util.run_in_graph_and_eager_modes
  def testEnd(self):
    # At the final step the rate has fully decayed to end_lr.
    self._assert_decayed(step=10, lr=0.05, end_lr=0.001, expected=0.001)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWayWithEnd(self):
    self._assert_decayed(
        step=5, lr=0.05, end_lr=0.001,
        expected=(0.05 - 0.001) * 0.5**0.5 + 0.001)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEnd(self):
    # Past the schedule's end the rate stays clamped at end_lr.
    self._assert_decayed(step=15, lr=0.05, end_lr=0.001, expected=0.001)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEndWithCycle(self):
    # With cycling, step 15 is a quarter of the way into the second cycle.
    self._assert_decayed(
        step=15, lr=0.05, end_lr=0.001,
        expected=(0.05 - 0.001) * 0.25**0.5 + 0.001, cycle=True)
+
+
class PolynomialDecayTestV2(test_util.TensorFlowTestCase):
  """Covers the cycle=True edge case at step 0 of polynomial_decay."""

  @test_util.run_in_graph_and_eager_modes
  def testBeginWithCycle(self):
    # At step 0 no decay has happened yet, even with cycling enabled.
    initial_lr = 0.001
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        initial_lr, 0, 10, cycle=True)
    self.assertAllClose(self.evaluate(decayed_lr()), initial_lr, 1e-6)
+
+
class ExponentialDecayTestV2(test_util.TensorFlowTestCase):
  """Verifies natural_exp_decay step-by-step against math.exp."""

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    base_lr, num_steps, rate = 0.1, 10, 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.natural_exp_decay(
        base_lr, step, num_steps, rate)

    self.evaluate(variables.global_variables_initializer())
    for t in range(num_steps + 1):
      # Continuous form: lr * exp(-t / num_steps * rate).
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          base_lr * math.exp(-t / num_steps * rate), 1e-6)
      self.evaluate(step.assign_add(1))

  @test_util.run_in_graph_and_eager_modes
  def testStaircase(self):
    base_lr, num_steps, rate = 0.1, 10, 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.natural_exp_decay(
        base_lr, step, num_steps, rate, staircase=True)

    self.evaluate(variables.global_variables_initializer())
    for t in range(num_steps + 1):
      # Staircase form: lr * exp(-rate * floor(t / num_steps)).
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          base_lr * math.exp(-rate * (t // num_steps)), 1e-6)
      self.evaluate(step.assign_add(1))
+
+
class InverseDecayTestV2(test_util.TensorFlowTestCase):
  """Verifies inverse_time_decay step-by-step against the closed form."""

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    base_lr, num_steps, rate = 0.1, 10, 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.inverse_time_decay(
        base_lr, step, num_steps, rate)

    self.evaluate(variables.global_variables_initializer())
    for t in range(num_steps + 1):
      # Continuous form: lr / (1 + t / num_steps * rate).
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          base_lr / (1 + t / num_steps * rate), 1e-6)
      self.evaluate(step.assign_add(1))

  @test_util.run_in_graph_and_eager_modes
  def testStaircase(self):
    base_lr, num_steps, rate = 0.1, 10, 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.inverse_time_decay(
        base_lr, step, num_steps, rate, staircase=True)

    self.evaluate(variables.global_variables_initializer())
    for t in range(num_steps + 1):
      # Staircase form: lr / (1 + rate * floor(t / num_steps)).
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          base_lr / (1 + rate * (t // num_steps)), 1e-6)
      self.evaluate(step.assign_add(1))
+
+
class CosineDecayTestV2(test_util.TensorFlowTestCase):
  """Compares cosine_decay against a pure-Python reference implementation."""

  def np_cosine_decay(self, step, decay_steps, alpha=0.0):
    """Reference value of the cosine-decay multiplier at `step`."""
    clipped = min(step, decay_steps)
    frac = clipped / decay_steps
    raw = 0.5 * (1.0 + math.cos(math.pi * frac))
    return (1.0 - alpha) * raw + alpha

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    total_steps = 1000
    base_lr = 1.0
    # Sample points both inside and beyond the decay window.
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay(
          base_lr, gstep, total_steps)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay(gstep, total_steps), 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testAlpha(self):
    total_steps = 1000
    base_lr = 1.0
    floor = 0.1
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay(
          base_lr, gstep, total_steps, floor)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay(gstep, total_steps, floor), 1e-6)
+
+
class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase):
  """Compares cosine_decay_restarts against a pure-Python SGDR reference."""

  def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
                               alpha=0.0):
    """Reference SGDR multiplier: walk past completed periods, then decay."""
    lr_factor = 1.0
    # Skip whole periods; each is t_mul longer and scales the rate by m_mul.
    while step >= decay_steps:
      step -= decay_steps
      decay_steps *= t_mul
      lr_factor *= m_mul

    frac = step / decay_steps
    value = lr_factor * 0.5 * (1.0 + math.cos(math.pi * frac))
    return (1.0 - alpha) * value + alpha

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    total_steps = 1000
    base_lr = 1.0
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          base_lr, gstep, total_steps)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay_restarts(gstep, total_steps), 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testAlpha(self):
    total_steps = 1000
    base_lr = 1.0
    floor = 0.1
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          base_lr, gstep, total_steps, alpha=floor)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay_restarts(gstep, total_steps, alpha=floor), 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testMMul(self):
    total_steps = 1000
    base_lr = 1.0
    m_mul = 0.9
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          base_lr, gstep, total_steps, m_mul=m_mul)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay_restarts(gstep, total_steps, m_mul=m_mul), 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testTMul(self):
    total_steps = 1000
    base_lr = 1.0
    t_mul = 1.0  # Constant-length periods (the special-cased branch).
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          base_lr, gstep, total_steps, t_mul=t_mul)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_cosine_decay_restarts(gstep, total_steps, t_mul=t_mul), 1e-6)
+
+
class LinearCosineDecayTestV2(test_util.TensorFlowTestCase):
  """Compares linear_cosine_decay against a pure-Python reference."""

  def np_linear_cosine_decay(self,
                             step,
                             decay_steps,
                             alpha=0.0,
                             beta=0.001,
                             num_periods=0.5):
    """Reference multiplier: (alpha + linear) * cosine + beta."""
    step = min(step, decay_steps)
    linear = float(decay_steps - step) / decay_steps
    frac = 2.0 * num_periods * step / float(decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
    return (alpha + linear) * cosine + beta

  @test_util.run_in_graph_and_eager_modes
  def testDefaultDecay(self):
    total_steps = 1000
    base_lr = 1.0
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
          base_lr, gstep, total_steps)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_linear_cosine_decay(gstep, total_steps), 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testNonDefaultDecay(self):
    total_steps = 1000
    base_lr = 1.0
    kwargs = dict(alpha=0.1, beta=1e-4, num_periods=5)
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
          base_lr, gstep, total_steps, **kwargs)
      self.assertAllClose(
          self.evaluate(decayed_lr()),
          self.np_linear_cosine_decay(gstep, total_steps, **kwargs), 1e-6)
+
+
class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase):
  """Smoke tests for noisy_linear_cosine_decay.

  The injected gaussian noise makes the output nondeterministic, so these
  tests only verify that the decay callable builds and evaluates cleanly.
  """

  @test_util.run_in_graph_and_eager_modes
  def testDefaultNoisyLinearCosine(self):
    total_steps = 1000
    base_lr = 1.0
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
          base_lr, gstep, total_steps)
      # No numerical comparison possible; just evaluate.
      self.evaluate(decayed_lr())

  @test_util.run_in_graph_and_eager_modes
  def testNonDefaultNoisyLinearCosine(self):
    total_steps = 1000
    base_lr = 1.0
    for gstep in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
          base_lr,
          gstep,
          total_steps,
          initial_variance=0.5,
          variance_decay=0.1,
          alpha=0.1,
          beta=1e-4,
          num_periods=5)
      # No numerical comparison possible; just evaluate.
      self.evaluate(decayed_lr())
+
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  googletest.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index f7e7807..8a21c39 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -123,7 +123,7 @@
           ]), self.evaluate(var1))
 
   def testBasic(self):
-    with self.test_session():
+    with self.cached_session():
       self.doTestBasic(use_resource=False)
 
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -162,7 +162,7 @@
 
   def testNesterovMomentum(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -188,7 +188,7 @@
 
   def testSparseNesterovMomentum(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
         accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
@@ -282,7 +282,7 @@
 
   def testTensorLearningRateAndMomentum(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -435,7 +435,7 @@
     return db_grad, db_out
 
   def testLikeDistBeliefMom01(self):
-    with self.test_session():
+    with self.cached_session():
       db_grad, db_out = self._dbParamsMom01()
       num_samples = len(db_grad)
       var0 = variables.Variable([0.0] * num_samples)
@@ -449,7 +449,7 @@
 
   def testSparse(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
         var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
         grads0 = ops.IndexedSlices(
@@ -518,7 +518,7 @@
 
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index ff586b6..2d7799d 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -80,7 +80,7 @@
       self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
       self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
       self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         self.assertItemsEqual([b'my_var', b'my_local_var'],
                               sess.run(scaffold.ready_op))
         self.assertItemsEqual([b'my_var'],
@@ -513,21 +513,21 @@
   """_WrappedSession tests."""
 
   def test_properties(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       constant_op.constant(0.0)
       wrapped_sess = monitored_session._WrappedSession(sess)
       self.assertEquals(sess.graph, wrapped_sess.graph)
       self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
 
   def test_should_stop_on_close(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       wrapped_sess = monitored_session._WrappedSession(sess)
       self.assertFalse(wrapped_sess.should_stop())
       wrapped_sess.close()
       self.assertTrue(wrapped_sess.should_stop())
 
   def test_should_stop_uses_check_stop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       wrapped_sess = StopAtNSession(sess, 3)
       self.assertFalse(wrapped_sess.should_stop())
       self.assertFalse(wrapped_sess.should_stop())
@@ -535,7 +535,7 @@
       self.assertTrue(wrapped_sess.should_stop())
 
   def test_should_stop_delegates_to_wrapped_session(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       wrapped_sess0 = StopAtNSession(sess, 4)
       wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0)
       self.assertFalse(wrapped_sess1.should_stop())
@@ -545,7 +545,7 @@
       self.assertTrue(wrapped_sess1.should_stop())
 
   def test_close_twice(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       wrapped_sess = monitored_session._WrappedSession(sess)
       wrapped_sess.close()
       self.assertTrue(wrapped_sess.should_stop())
@@ -553,7 +553,7 @@
       self.assertTrue(wrapped_sess.should_stop())
 
   def test_run(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(0)
       v = array_ops.identity(c)
       self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
@@ -570,7 +570,7 @@
   """_CoordinatedSession tests."""
 
   def test_properties(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       constant_op.constant(0.0)
       coord = coordinator.Coordinator()
       coord_sess = monitored_session._CoordinatedSession(sess, coord)
@@ -578,7 +578,7 @@
       self.assertEquals(sess.sess_str, coord_sess.sess_str)
 
   def test_run(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(0)
       v = array_ops.identity(c)
       coord = coordinator.Coordinator()
@@ -586,7 +586,7 @@
       self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
 
   def test_should_stop_on_close(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       coord_sess = monitored_session._CoordinatedSession(sess, coord)
       self.assertFalse(coord_sess.should_stop())
@@ -594,7 +594,7 @@
       self.assertTrue(coord_sess.should_stop())
 
   def test_should_stop_on_coord_stop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       coord_sess = monitored_session._CoordinatedSession(sess, coord)
       self.assertFalse(coord_sess.should_stop())
@@ -602,7 +602,7 @@
       self.assertTrue(coord_sess.should_stop())
 
   def test_dont_request_stop_on_exception_in_main_thread(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(0)
       v = array_ops.identity(c)
       coord = coordinator.Coordinator()
@@ -616,7 +616,7 @@
       self.assertFalse(coord_sess.should_stop())
 
   def test_stop_threads_on_close_after_exception(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(0)
       v = array_ops.identity(c)
       coord = coordinator.Coordinator()
@@ -646,7 +646,7 @@
       self.assertTrue(coord_sess.should_stop())
 
   def test_stop_threads_on_close(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator()
       threads = [
           threading.Thread(
@@ -664,7 +664,7 @@
 
   def test_propagates_exception_trace(self):
     assertion = control_flow_ops.Assert(False, ['This should fail.'])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       coord = coordinator.Coordinator(clean_stop_exception_types=())
       coord_sess = monitored_session._CoordinatedSession(sess, coord)
       try:
@@ -810,7 +810,7 @@
       return self._sess
 
   def test_properties(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       constant_op.constant(0.0)
       recoverable_sess = monitored_session._RecoverableSession(
           self._SessionReturner(sess))
@@ -818,7 +818,7 @@
       self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
 
   def test_run(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       c = constant_op.constant(0)
       v = array_ops.identity(c)
       recoverable_sess = monitored_session._RecoverableSession(
@@ -826,7 +826,7 @@
       self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
 
   def test_recovery(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
 
       class StackSessionCreator(object):
 
@@ -872,7 +872,7 @@
         recoverable_sess.run(v, feed_dict={c: -12})
 
   def test_recovery_from_coordinator_exception(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = monitored_session.MonitoredSession(
           session_creator,
@@ -897,7 +897,7 @@
       self.assertEqual(2, session_creator.number_of_sessions_created)
 
   def test_recovery_from_non_preemption_in_coordinator(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       hook = StopCoordinatorWithException(
           calls_before_stopping=2,
@@ -926,7 +926,7 @@
         session.close()
 
   def test_recovery_from_session_getting_stuck(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = monitored_session.MonitoredSession(
           session_creator,
@@ -950,7 +950,7 @@
       self.assertEqual(2, session_creator.number_of_sessions_created)
 
   def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = monitored_session.MonitoredSession(
           session_creator,
@@ -980,7 +980,7 @@
       self.assertEqual(2, session_creator.number_of_sessions_created)
 
   def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       hook = StopCoordinatorWithException(
           calls_before_stopping=2,
@@ -1014,7 +1014,7 @@
         session.close()
 
   def test_recovery_from_session_getting_stuck_when_run_hooks(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = monitored_session.MonitoredSession(
           session_creator,
@@ -1058,7 +1058,7 @@
     return session
 
   def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = self.create_raw_session_with_failing_coordinator(
           session_creator,
@@ -1090,7 +1090,7 @@
       self.assertEqual(2, session_creator.number_of_sessions_created)
 
   def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = self.create_raw_session_with_failing_coordinator(
           session_creator,
@@ -1127,7 +1127,7 @@
         session.close()
 
   def test_recovery_from_session_getting_stuck_with_raw_session(self):
-    with self.test_session() as test_session:
+    with self.cached_session() as test_session:
       session_creator = CountingSessionCreator(test_session)
       session = self.create_raw_session_with_failing_coordinator(
           session_creator,
@@ -2047,7 +2047,7 @@
 
         return value
 
-      with self.test_session() as test_session:
+      with self.cached_session() as test_session:
         with monitored_session.MonitoredSession(
             CountingSessionCreator(test_session)) as session:
           session.run(variables.global_variables_initializer())
@@ -2110,7 +2110,7 @@
         step_context.session.run(graph_side_effect)
         return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
 
-      with self.test_session() as test_session:
+      with self.cached_session() as test_session:
         with monitored_session.MonitoredSession(
             CountingSessionCreator(test_session),
             hooks=[Hook(self)]) as session:
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index fdb8d79..93991d0 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -35,7 +35,7 @@
 class MovingAveragesTest(test.TestCase):
 
   def testAssignMovingAverageWithoutZeroDebias(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable([10.0, 11.0])
       val = constant_op.constant([1.0, 2.0], dtypes.float32)
       decay = 0.25
@@ -49,7 +49,7 @@
           var.eval())
 
   def testAssignMovingAverage(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable([0.0, 0.0])
       val = constant_op.constant([1.0, 2.0], dtypes.float32)
       decay = 0.25
@@ -86,7 +86,7 @@
       moving_averages.assign_moving_average(var, 0.0, 0.99)
 
   def testWeightedMovingAverage(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       decay = 0.5
       weight = array_ops.placeholder(dtypes.float32, [])
       val = array_ops.placeholder(dtypes.float32, [])
@@ -187,53 +187,53 @@
     self.assertAllClose(expected, avg2.eval())
 
   def testAverageVariablesNoNumUpdates_Scalar(self):
-    with self.test_session():
+    with self.cached_session():
       ema = moving_averages.ExponentialMovingAverage(0.25)
       self._CheckDecay(ema, actual_decay=0.25, dim=1)
 
   def testAverageVariablesNoNumUpdates_Scalar_Debias(self):
-    with self.test_session():
+    with self.cached_session():
       ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
       self._CheckDecay(ema, actual_decay=0.25, dim=1)
 
   def testAverageVariablesNoNumUpdates_Vector(self):
-    with self.test_session():
+    with self.cached_session():
       ema = moving_averages.ExponentialMovingAverage(0.25)
       self._CheckDecay(ema, actual_decay=0.25, dim=5)
 
   def testAverageVariablesNoNumUpdates_Vector_Debias(self):
-    with self.test_session():
+    with self.cached_session():
       ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
       self._CheckDecay(ema, actual_decay=0.25, dim=5)
 
   def testAverageVariablesNumUpdates_Scalar(self):
-    with self.test_session():
+    with self.cached_session():
       # With num_updates 1, the decay applied is 0.1818
       ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
       self._CheckDecay(ema, actual_decay=0.181818, dim=1)
 
   def testAverageVariablesNumUpdates_Scalar_Debias(self):
-    with self.test_session():
+    with self.cached_session():
       # With num_updates 1, the decay applied is 0.1818
       ema = moving_averages.ExponentialMovingAverage(
           0.25, num_updates=1, zero_debias=True)
       self._CheckDecay(ema, actual_decay=0.181818, dim=1)
 
   def testAverageVariablesNumUpdates_Vector(self):
-    with self.test_session():
+    with self.cached_session():
       # With num_updates 1, the decay applied is 0.1818
       ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
       self._CheckDecay(ema, actual_decay=0.181818, dim=5)
 
   def testAverageVariablesNumUpdates_Vector_Debias(self):
-    with self.test_session():
+    with self.cached_session():
       # With num_updates 1, the decay applied is 0.1818
       ema = moving_averages.ExponentialMovingAverage(
           0.25, num_updates=1, zero_debias=True)
       self._CheckDecay(ema, actual_decay=0.181818, dim=5)
 
   def testAverageVariablesWithControlDeps(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variables.Variable(0, name="v0")
       add_to_v0 = v0.assign_add(1)
       v1 = variables.Variable([10.0], name="v1")
@@ -276,7 +276,7 @@
     self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
 
   def averageVariablesNamesHelper(self, zero_debias):
-    with self.test_session():
+    with self.cached_session():
       v0 = variables.Variable(10.0, name="v0")
       v1 = variables.Variable(30.0, name="v1")
       # Add a non-trainable variable.
@@ -320,7 +320,7 @@
 
   def averageVariablesNamesRespectScopeHelper(self, zero_debias):
     # See discussion on #2740.
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("scope1"):
         v0 = variables.Variable(10.0, name="v0")
         v1 = variables.Variable(30.0, name="v1")
@@ -367,7 +367,7 @@
     self.averageVariablesNamesRespectScopeHelper(zero_debias=False)
 
   def testSubsetAverageVariablesNames(self):
-    with self.test_session():
+    with self.cached_session():
       v0 = variables.Variable(10.0, name="v0")
       v1 = variables.Variable(30.0, name="v1")
       # Add a non-trainable variable.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index dfe9176..7a7d01d 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -64,7 +64,7 @@
 
   def testAggregationMethod(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         cost = 5 * var0 + 3 * var1
@@ -89,7 +89,7 @@
 
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         cost = 5 * var0 + 3 * var1
@@ -231,7 +231,7 @@
       sgd_op.apply_gradients(grads_and_vars)
 
   def testTrainOp(self):
-    with self.test_session():
+    with self.cached_session():
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([3.0, 4.0])
       cost = 5 * var0 + 3 * var1
@@ -244,7 +244,7 @@
   def testConstraint(self):
     constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
     constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
-    with self.test_session():
+    with self.cached_session():
       var0 = variables.Variable([1.0, 2.0],
                                 constraint=constraint_01)
       var1 = variables.Variable([3.0, 4.0],
diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py
index 430c16b..74e06a5 100644
--- a/tensorflow/python/training/proximal_adagrad_test.py
+++ b/tensorflow/python/training/proximal_adagrad_test.py
@@ -35,7 +35,7 @@
 class ProximalAdagradOptimizerTest(test.TestCase):
 
   def doTestProximalAdagradwithoutRegularization(self, use_resource=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([0.0, 0.0])
       var1 = variables.Variable([0.0, 0.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -71,7 +71,7 @@
     self.doTestProximalAdagradwithoutRegularization(use_resource=True)
 
   def testProximalAdagradwithoutRegularization2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -98,7 +98,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -114,7 +114,7 @@
             [[0, 1]], var0.eval(), atol=0.01)
 
   def testProximalAdagradWithL1(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -140,7 +140,7 @@
       self.assertAllClose(np.array([2.959304, 1.029232]), v1_val)
 
   def testProximalAdagradWithL1_L2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -206,7 +206,7 @@
     return v0_val, v1_val
 
   def testEquivAdagradwithoutRegularization(self):
-    with self.test_session():
+    with self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_adagrad.ProximalAdagradOptimizer(
               3.0,
@@ -214,7 +214,7 @@
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0))
 
-    with self.test_session():
+    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           adagrad.AdagradOptimizer(
               3.0, initial_accumulator_value=0.1))
@@ -223,7 +223,7 @@
     self.assertAllClose(val1, val3)
 
   def testEquivSparseAdagradwithoutRegularization(self):
-    with self.test_session():
+    with self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_adagrad.ProximalAdagradOptimizer(
               3.0,
@@ -232,7 +232,7 @@
               l2_regularization_strength=0.0),
           is_sparse=True)
 
-    with self.test_session():
+    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           adagrad.AdagradOptimizer(
               3.0, initial_accumulator_value=0.1),
diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py
index 4e4812f..f77f68b 100644
--- a/tensorflow/python/training/proximal_gradient_descent_test.py
+++ b/tensorflow/python/training/proximal_gradient_descent_test.py
@@ -36,7 +36,7 @@
 
   def doTestProximalGradientDescentwithoutRegularization(
       self, use_resource=False):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if use_resource:
         var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
         var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
@@ -69,7 +69,7 @@
     self.doTestProximalGradientDescentwithoutRegularization(use_resource=True)
 
   def testProximalGradientDescentwithoutRegularization2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -94,7 +94,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -111,7 +111,7 @@
             [[-111, -138]], var0.eval(), atol=0.01)
 
   def testProximalGradientDescentWithL1_L2(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -174,7 +174,7 @@
     return v0_val, v1_val
 
   def testEquivSparseGradientDescentwithoutRegularization(self):
-    with self.test_session():
+    with self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_gradient_descent.ProximalGradientDescentOptimizer(
               3.0,
@@ -182,7 +182,7 @@
               l2_regularization_strength=0.0),
           is_sparse=True)
 
-    with self.test_session():
+    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True)
 
@@ -190,14 +190,14 @@
     self.assertAllClose(val1, val3)
 
   def testEquivGradientDescentwithoutRegularization(self):
-    with self.test_session():
+    with self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_gradient_descent.ProximalGradientDescentOptimizer(
               3.0,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0))
 
-    with self.test_session():
+    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           gradient_descent.GradientDescentOptimizer(3.0))
 
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 900f970..9b9e28a 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -41,7 +41,7 @@
 class QueueRunnerTest(test.TestCase):
 
   def testBasic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
       var = variables.Variable(zero64)
@@ -61,7 +61,7 @@
       self.assertEqual(3, var.eval())
 
   def testTwoOps(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
       var0 = variables.Variable(zero64)
@@ -84,7 +84,7 @@
       self.assertEqual(30, var1.eval())
 
   def testExceptionsCaptured(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
       qr = queue_runner_impl.QueueRunner(queue, [_MockOp("i fail"),
                                                  _MockOp("so fail")])
@@ -100,7 +100,7 @@
       self.assertTrue("Operation not in the graph" in str(exceptions[1]))
 
   def testRealDequeueEnqueue(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       q0 = data_flow_ops.FIFOQueue(3, dtypes.float32)
       enqueue0 = q0.enqueue((10.0,))
       close0 = q0.close()
@@ -128,7 +128,7 @@
         dequeue1.eval()
 
   def testRespectCoordShouldStop(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
       var = variables.Variable(zero64)
@@ -152,7 +152,7 @@
       self.assertEqual(0, var.eval())
 
   def testRequestStopOnException(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
       qr = queue_runner_impl.QueueRunner(queue, [_MockOp("not an op")])
       coord = coordinator.Coordinator()
@@ -164,7 +164,7 @@
         coord.join()
 
   def testGracePeriod(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # The enqueue will quickly block.
       queue = data_flow_ops.FIFOQueue(2, dtypes.float32)
       enqueue = queue.enqueue((10.0,))
@@ -181,7 +181,7 @@
       coord.join(stop_grace_period_secs=1.0)
 
   def testMultipleSessions(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       with session.Session() as other_sess:
         zero64 = constant_op.constant(0, dtype=dtypes.int64)
         var = variables.Variable(zero64)
@@ -196,7 +196,7 @@
         self.assertEqual(len(threads), len(other_threads))
 
   def testIgnoreMultiStarts(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
       var = variables.Variable(zero64)
@@ -212,7 +212,7 @@
       self.assertEqual([], new_threads)
 
   def testThreads(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
       zero64 = constant_op.constant(0, dtype=dtypes.int64)
       var = variables.Variable(zero64)
@@ -256,7 +256,7 @@
     init_op = variables.global_variables_initializer()
     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
     queue_runner_impl.add_queue_runner(qr)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       init_op.run()
       threads = queue_runner_impl.start_queue_runners(sess)
       for t in threads:
@@ -273,7 +273,7 @@
     init_op = variables.global_variables_initializer()
     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
     queue_runner_impl.add_queue_runner(qr)
-    with self.test_session():
+    with self.cached_session():
       init_op.run()
       with self.assertRaisesRegexp(TypeError, "tf.Session"):
         queue_runner_impl.start_queue_runners("NotASession")
@@ -286,7 +286,7 @@
     init_op = variables.global_variables_initializer()
     qr = queue_runner_impl.QueueRunner(queue, [count_up_to])
     queue_runner_impl.add_queue_runner(qr)
-    with self.test_session():
+    with self.cached_session():
       init_op.run()
       threads = queue_runner_impl.start_queue_runners(
           monitored_session.MonitoredSession())
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 6043327..4f5f96e 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -165,7 +165,7 @@
 
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -187,7 +187,7 @@
 
   def testMinimizeSparseResourceVariableCentered(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.test_session():
+      with self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f5b2a22..0ac8481 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -324,7 +324,7 @@
         save_relative_paths=True)
     init_all_op = [variables.global_variables_initializer(), v2_init]
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize all variables
       sess.run(init_all_op)
 
@@ -349,7 +349,7 @@
 
     # Start a second session.  In that session the parameter nodes
     # have not been initialized either.
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v0 = variables.Variable(-1.0, name="v0")
       v1 = variables.Variable(-1.0, name="v1")
       v2 = saver_test_utils.CheckpointedOp(name="v2")
@@ -373,7 +373,7 @@
     v0 = variables.Variable(0, name="v0")
     filename = b"somerandomfilename"
     save = saver_module.Saver({"v0": v0}, filename=filename)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tensor = sess.graph.get_tensor_by_name(
           save.saver_def.filename_tensor_name)
       self.assertEqual(sess.run(tensor), filename)
@@ -381,7 +381,7 @@
   def testInvalidPath(self):
     v0 = variables.Variable(0, name="v0")
     for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         save = saver_module.Saver({"v0": v0}, write_version=ver)
         with self.assertRaisesRegexp(
             ValueError, "The passed save_path is not a valid checkpoint:"):
@@ -390,7 +390,7 @@
   def testInt64(self):
     save_path = os.path.join(self.get_temp_dir(), "int64")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Build a graph with 1 node, and save and restore for them.
       v = variables.Variable(np.int64(15), name="v")
       save = saver_module.Saver({"v": v}, restore_sequentially=True)
@@ -401,7 +401,7 @@
       self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         v = variables.Variable(np.int64(-1), name="v")
         save = saver_module.Saver({"v": v})
 
@@ -559,12 +559,12 @@
 
   def testAllowEmpty(self):
     save_path = os.path.join(self.get_temp_dir(), "allow_empty")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       _ = constant_op.constant(1)
       save = saver_module.Saver(allow_empty=True)
       val = save.save(sess, save_path)
       self.assertIsNone(val)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       save = saver_module.Saver(allow_empty=True)
       save.restore(sess, save_path)
 
@@ -740,7 +740,7 @@
       # save succeeds or fails is implementation dependent.  Therefore we allow
       # both cases.
       try:
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           # Initialize all variables
           sess.run(init_all_op)
 
@@ -751,7 +751,7 @@
           # Save the graph.
           save.save(sess, save_path)
 
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           # Restore the saved values in the parameter nodes.
           save.restore(sess, save_path)
           # Check that the parameter nodes have been restored.
@@ -775,7 +775,7 @@
     save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
     init_all_op = variables.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initialize all variables
       sess.run(init_all_op)
 
@@ -983,7 +983,7 @@
           os.path.join(self.get_temp_dir(), "sharded_basics"))
 
   def testSaverDef(self):
-    with self.test_session():
+    with self.cached_session():
       v0 = variables.Variable(123, name="v0")
       save = saver_module.Saver({"v0": v0}, sharded=True)
       sd = save.as_saver_def()
@@ -1209,7 +1209,7 @@
   def testNonSharded(self):
     save_dir = self._get_test_dir("max_to_keep_non_sharded")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = variables.Variable(10.0, name="v")
       save = saver_module.Saver({"v": v}, max_to_keep=2)
       variables.global_variables_initializer().run()
@@ -1447,7 +1447,7 @@
     save_dir = self._get_test_dir("no_max_to_keep")
     save_dir2 = self._get_test_dir("max_to_keep_0")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = variables.Variable(10.0, name="v")
       variables.global_variables_initializer().run()
 
@@ -1474,7 +1474,7 @@
   def testNoMetaGraph(self):
     save_dir = self._get_test_dir("no_meta_graph")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = variables.Variable(10.0, name="v")
       save = saver_module.Saver({"v": v})
       variables.global_variables_initializer().run()
@@ -1497,7 +1497,7 @@
   def testNonSharded(self, mock_time):
     save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       v = variable_scope.variable([10.0], name="v")
       # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
       # call, which throws the test timing off in fastbuild mode.
@@ -1630,7 +1630,7 @@
   def testAddCollectionDef(self):
     test_dir = self._get_test_dir("good_collection")
     filename = os.path.join(test_dir, "metafile")
-    with self.test_session():
+    with self.cached_session():
       # Creates a graph.
       v0 = variables.Variable(1.0, name="v0")
       control_flow_ops.cond(
@@ -1685,7 +1685,7 @@
         self, meta_graph_def, new_meta_graph_def)
 
   def testAddCollectionDefFails(self):
-    with self.test_session():
+    with self.cached_session():
       # Creates a graph.
       v0 = variables.Variable(10.0, name="v0")
       # Creates a saver.
@@ -1870,7 +1870,7 @@
   def testSliceVariable(self):
     test_dir = self._get_test_dir("slice_saver")
     filename = os.path.join(test_dir, "metafile")
-    with self.test_session():
+    with self.cached_session():
       v1 = variables.Variable([20.0], name="v1")
       v2 = variables.Variable([20.0], name="v2")
       v2._set_save_slice_info(
@@ -1946,7 +1946,7 @@
       ops_lib.add_to_collection("logits", logits)
     init_all_op = variables.global_variables_initializer()
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # Initializes all the variables.
       sess.run(init_all_op)
       # Runs to logit.
@@ -2120,7 +2120,7 @@
     # pylint: enable=g-long-lambda
 
   def testStrippedOpListDef(self):
-    with self.test_session():
+    with self.cached_session():
       # Creates a graph.
       v0 = variables.Variable(0.0)
       var = variables.Variable(10.0)
@@ -2160,7 +2160,7 @@
 
     # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
     # (complex64) in the "Complex" op must be removed.
-    with self.test_session():
+    with self.cached_session():
       real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
       imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
       math_ops.complex(real_num, imag_num, name="complex")
@@ -2397,7 +2397,7 @@
         }, write_version=self._WRITE_VERSION)
     save_path = os.path.join(self.get_temp_dir(),
                              "ckpt_for_debug_string" + str(self._WRITE_VERSION))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(init_all_op)
       # Saves a checkpoint.
       save.save(sess, save_path)
@@ -2853,7 +2853,7 @@
     saver = saver_module.Saver(var_list=[v])
     test_dir = self.get_temp_dir()
     prefix = os.path.join(test_dir, "ckpt")
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.evaluate(v.non_dep_variable.assign(42.))
       save_path = saver.save(sess, prefix)
       self.evaluate(v.non_dep_variable.assign(43.))
@@ -2867,7 +2867,7 @@
     test_dir = self.get_temp_dir()
     prefix = os.path.join(test_dir, "ckpt")
     self.evaluate(v.non_dep_variable.assign(42.))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       save_path = saver.save(sess, prefix)
       self.evaluate(v.non_dep_variable.assign(43.))
       self.evaluate(v.mirrored.assign(44.))
@@ -2900,7 +2900,7 @@
       saver = saver_module.Saver(var_list=[v])
       test_dir = self.get_temp_dir()
       prefix = os.path.join(test_dir, "ckpt")
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         save_path = saver.save(sess, prefix)
         self.assertEqual(1, v.eval_count)
         saver.restore(sess, save_path)
@@ -2957,7 +2957,7 @@
     b = resource_variable_ops.ResourceVariable(1., name="b")
     a_saver = saver_module.Saver([a])
     b_saver = saver_module.Saver([b])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       sess.run(a.initializer)
       save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
       with self.assertRaisesRegexp(
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index d7e6dac..f1d18f7 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -98,7 +98,7 @@
       os.rename(checkpoint_dir, checkpoint_dir2)
       gfile.MakeDirs(checkpoint_dir)
       v = variables.Variable([6.0, 7.0, 8.0], name="v")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
       session_manager.SessionManager(
           ready_op=variables.report_uninitialized_variables())
@@ -236,7 +236,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -294,7 +294,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -326,7 +326,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
           ready_op=variables.report_uninitialized_variables(),
@@ -362,7 +362,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -467,7 +467,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="x")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
         self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -519,7 +519,7 @@
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="x_res")
 
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
         self.assertEqual(False, variables.is_variable_initialized(x).eval())
@@ -566,7 +566,7 @@
     with ops.Graph().as_default():
       i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
       v = variables.Variable(array_ops.identity(i), name="v")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
       sm = session_manager.SessionManager(
           ready_op=variables.report_uninitialized_variables())
@@ -585,7 +585,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -602,7 +602,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -619,7 +619,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -640,7 +640,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="w")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
         self.assertEqual(False, variables.is_variable_initialized(w).eval())
       sm2 = session_manager.SessionManager(
@@ -714,7 +714,7 @@
       os.rename(checkpoint_dir, checkpoint_dir2)
       gfile.MakeDirs(checkpoint_dir)
       v = variables.Variable([6.0, 7.0, 8.0], name="v")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
       session_manager.SessionManager(
           ready_op=variables.assert_variables_initialized())
@@ -769,7 +769,7 @@
     # Create a new Graph and SessionManager and recover.
     with ops.Graph().as_default():
       v = variables.Variable(2, name="v")
-      with self.test_session():
+      with self.cached_session():
         self.assertEqual(False, variables.is_variable_initialized(v).eval())
       sm2 = session_manager.SessionManager(
           ready_op=variables.assert_variables_initialized())
diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py
index 08a3c8d..6d63641 100644
--- a/tensorflow/python/training/slot_creator_test.py
+++ b/tensorflow/python/training/slot_creator_test.py
@@ -32,7 +32,7 @@
 class SlotCreatorTest(test.TestCase):
 
   def testCreateSlotFromVariable(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable([1.0, 2.5], name="var")
       slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
 
@@ -44,7 +44,7 @@
       self.assertAllEqual([1.0, 2.5], slot.eval())
 
   def testCreateSlotFromTensor(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant([1.0, 2.5], name="const")
       slot = slot_creator.create_slot(v, v * 2, name="slot")
 
@@ -56,7 +56,7 @@
       self.assertAllEqual([2.0, 5.0], slot.eval())
 
   def testCreateZerosSlotFromVariable(self):
-    with self.test_session():
+    with self.cached_session():
       v = variables.Variable([1.0, 2.5], name="var")
       with ops.control_dependencies(None):
         slot = slot_creator.create_zeros_slot(
@@ -70,7 +70,7 @@
       self.assertAllEqual([0.0, 0.0], slot.eval())
 
   def testCreateZerosSlotFromDynamicShapedVariable(self):
-    with self.test_session():
+    with self.cached_session():
       dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
       dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                      shape=[None])
@@ -91,7 +91,7 @@
       self.assertAllEqual([0.0, 0.0], slot.eval())
 
   def testCreateZerosSlotFromTensor(self):
-    with self.test_session():
+    with self.cached_session():
       v = constant_op.constant([1.0, 2.5], name="const")
       with ops.control_dependencies(None):
         slot = slot_creator.create_zeros_slot(v, name="slot")
@@ -104,7 +104,7 @@
       self.assertAllEqual([0.0, 0.0], slot.eval())
 
   def testCreateZerosSlotFromDynamicShapedTensor(self):
-    with self.test_session():
+    with self.cached_session():
       v = random_ops.random_uniform([2], dtype=dtypes.float64)
       v = array_ops.placeholder_with_default(v, shape=[None], name="const")
       with ops.control_dependencies(None):
@@ -120,7 +120,7 @@
 
   def testCreateSlotFromVariableRespectsScope(self):
     # See discussion on #2740.
-    with self.test_session():
+    with self.cached_session():
       with variable_scope.variable_scope("scope"):
         v = variables.Variable([1.0, 2.5], name="var")
         slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index 71ed880..caf6eba 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -795,7 +795,7 @@
 
     self.assertRaises(StopIteration, lambda: next(rr))
     # There should be a checkpoint file with the variable "foo"
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([10.10], name="foo")
       sav = saver_lib.Saver([v])
       sav.restore(sess, save_path)
@@ -859,14 +859,14 @@
     self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
     self.assertRaises(StopIteration, lambda: next(rr))
     # There should be a checkpoint file with the variable "foo"
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       v = variables.Variable([-12], name="global_step")
       sav = saver_lib.Saver([v])
       sav.restore(sess, save_path)
       self.assertEqual(123, v.eval()[0])
 
   def testNoQueueRunners(self):
-    with ops.Graph().as_default(), self.test_session() as sess:
+    with ops.Graph().as_default(), self.cached_session() as sess:
       sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
       self.assertEqual(0, len(sv.start_queue_runners(sess)))
       sv.stop()
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index c0dd46b..bea9bb6 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -41,6 +41,7 @@
         "old_vocab",
         "old_vocab_size",
         "backup_initializer",
+        "axis",
     ])):
   """Vocabulary information for warm-starting.
 
@@ -62,6 +63,42 @@
     backup_initializer: [Optional] A variable initializer used for variables
       corresponding to new vocabulary entries and OOV. If not provided, these
       entries will be zero-initialized.
+    axis: [Optional] Denotes what axis the vocabulary corresponds to.  The
+      default, 0, corresponds to the most common use case (embeddings or
+      linear weights for binary classification / regression).  An axis of 1
+      could be used for warm-starting output layers with class vocabularies.
+
+      For example:
+
+      embeddings_vocab_info = tf.VocabInfo(
+          new_vocab='embeddings_vocab',
+          new_vocab_size=100,
+          num_oov_buckets=1,
+          old_vocab='pretrained_embeddings_vocab',
+          old_vocab_size=10000,
+          backup_initializer=tf.truncated_normal_initializer(
+              mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
+          axis=0)
+
+      softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
+          new_vocab='class_vocab',
+          new_vocab_size=5,
+          num_oov_buckets=0,  # No OOV for classes.
+          old_vocab='old_class_vocab',
+          old_vocab_size=8,
+          backup_initializer=tf.glorot_uniform_initializer(),
+          axis=1)
+
+      softmax_output_layer_bias_vocab_info = tf.VocabInfo(
+          new_vocab='class_vocab',
+          new_vocab_size=5,
+          num_oov_buckets=0,  # No OOV for classes.
+          old_vocab='old_class_vocab',
+          old_vocab_size=8,
+          backup_initializer=tf.zeros_initializer(),
+          axis=0)
+
+      Currently, only axis=0 and axis=1 are supported.
   """
 
   def __new__(cls,
@@ -70,7 +107,12 @@
               num_oov_buckets,
               old_vocab,
               old_vocab_size=-1,
-              backup_initializer=None):
+              backup_initializer=None,
+              axis=0):
+    if axis != 0 and axis != 1:
+      raise ValueError("The only supported values for the axis argument are 0 "
+                       "and 1.  Provided axis: {}".format(axis))
+
     return super(VocabInfo, cls).__new__(
         cls,
         new_vocab,
@@ -79,6 +121,7 @@
         old_vocab,
         old_vocab_size,
         backup_initializer,
+        axis,
     )
 
 
@@ -149,7 +192,8 @@
                                previous_vocab_size=-1,
                                current_oov_buckets=0,
                                prev_tensor_name=None,
-                               initializer=None):
+                               initializer=None,
+                               axis=0):
   """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.
 
   Use this method when the `var` is backed by vocabulary. This method stitches
@@ -180,6 +224,7 @@
       None, we lookup tensor with same name as given `var`.
     initializer: Variable initializer to be used for missing entries.  If None,
       missing entries will be zero-initialized.
+    axis: Axis of the variable that the provided vocabulary corresponds to.
 
   Raises:
     ValueError: If required args are not provided.
@@ -204,6 +249,8 @@
     # Assume tensor name remains the same.
     prev_tensor_name = _infer_var_name(var)
 
+  # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
+  total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
   for v in var:
     v_shape = v.get_shape().as_list()
     slice_info = v._get_save_slice_info()
@@ -213,19 +260,45 @@
           full_shape=slice_info.full_shape,
           var_offset=slice_info.var_offset)
 
-    # TODO(eddz): Support cases where class vocabularies need remapping too.
+    if axis == 0:
+      new_row_vocab_size = current_vocab_size
+      new_col_vocab_size = v_shape[1]
+      old_row_vocab_size = previous_vocab_size
+      old_row_vocab_file = prev_vocab_path
+      new_row_vocab_file = current_vocab_path
+      old_col_vocab_file = None
+      new_col_vocab_file = None
+      num_row_oov_buckets = current_oov_buckets
+      num_col_oov_buckets = 0
+    elif axis == 1:
+      # Note that we must compute this value across all partitions, whereas
+      # in the axis = 0 case, we can simply use v_shape[1] because we don't
+      # allow partitioning across axis = 1.
+      new_row_vocab_size = total_v_first_axis
+      new_col_vocab_size = current_vocab_size
+      old_row_vocab_size = -1
+      old_row_vocab_file = None
+      new_row_vocab_file = None
+      old_col_vocab_file = prev_vocab_path
+      new_col_vocab_file = current_vocab_path
+      num_row_oov_buckets = 0
+      num_col_oov_buckets = current_oov_buckets
+    else:
+      raise ValueError("The only supported values for the axis argument are 0 "
+                       "and 1.  Provided axis: {}".format(axis))
+
     init = checkpoint_ops._load_and_remap_matrix_initializer(
         ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
         old_tensor_name=prev_tensor_name,
-        new_row_vocab_size=current_vocab_size,
-        new_col_vocab_size=v_shape[1],
-        old_row_vocab_size=previous_vocab_size,
-        old_row_vocab_file=prev_vocab_path,
-        new_row_vocab_file=current_vocab_path,
-        old_col_vocab_file=None,
-        new_col_vocab_file=None,
-        num_row_oov_buckets=current_oov_buckets,
-        num_col_oov_buckets=0,
+        new_row_vocab_size=new_row_vocab_size,
+        new_col_vocab_size=new_col_vocab_size,
+        old_row_vocab_size=old_row_vocab_size,
+        old_row_vocab_file=old_row_vocab_file,
+        new_row_vocab_file=new_row_vocab_file,
+        old_col_vocab_file=old_col_vocab_file,
+        new_col_vocab_file=new_col_vocab_file,
+        num_row_oov_buckets=num_row_oov_buckets,
+        num_col_oov_buckets=num_col_oov_buckets,
         initializer=initializer)
     new_init_val = ops.convert_to_tensor(
         init(shape=v_shape, partition_info=partition_info))
@@ -374,7 +447,8 @@
           previous_vocab_size=vocab_info.old_vocab_size,
           current_oov_buckets=vocab_info.num_oov_buckets,
           prev_tensor_name=prev_var_name,
-          initializer=vocab_info.backup_initializer)
+          initializer=vocab_info.backup_initializer,
+          axis=vocab_info.axis)
     else:
       # For the special value of vars_to_warm_start = None,
       # we only warm-start variables with explicitly specified vocabularies.
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 70a84bc3..6c860cd 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -107,7 +107,7 @@
             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
         ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+        self.assertAllClose(prev_val, fruit_weights.eval(sess))
 
   def testWarmStartVarPrevVarPartitioned(self):
     _, weights = self._create_prev_run_var(
@@ -123,7 +123,7 @@
             "fruit_weights", initializer=[[0.], [0.], [0.], [0.]])
         ws_util._warm_start_var(fruit_weights, self.get_temp_dir())
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual(prev_val, fruit_weights.eval(sess))
+        self.assertAllClose(prev_val, fruit_weights.eval(sess))
 
   def testWarmStartVarCurrentVarPartitioned(self):
     _, prev_val = self._create_prev_run_var(
@@ -143,7 +143,7 @@
         fruit_weights = fruit_weights._get_variable_list()
         new_val = np.concatenate(
             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-        self.assertAllEqual(prev_val, new_val)
+        self.assertAllClose(prev_val, new_val)
 
   def testWarmStartVarBothVarsPartitioned(self):
     _, weights = self._create_prev_run_var(
@@ -170,7 +170,7 @@
         fruit_weights = fruit_weights._get_variable_list()
         new_val = np.concatenate(
             [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0)
-        self.assertAllEqual(prev_val, new_val)
+        self.assertAllClose(prev_val, new_val)
 
   def testWarmStartVarWithVocab(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
@@ -189,9 +189,34 @@
         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                            self.get_temp_dir(), prev_vocab_path)
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
+  def testWarmStartVarWithColumnVocab(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
   def testWarmStartVarWithVocabConstrainedOldVocabSize(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -215,7 +240,7 @@
             previous_vocab_size=2)
         sess.run(variables.global_variables_initializer())
         # Old vocabulary limited to ['apple', 'banana'].
-        self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]],
+        self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
   def testWarmStartVarWithVocabPrevVarPartitioned(self):
@@ -238,9 +263,36 @@
         ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5,
                                            self.get_temp_dir(), prev_vocab_path)
         sess.run(variables.global_variables_initializer())
-        self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]],
+        self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]],
                             fruit_weights.eval(sess))
 
+  def testWarmStartVarWithColumnVocabPrevVarPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        shape=[4, 2],
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+        partitioner=lambda shape, dtype: [2, 1])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.],
+                             [2.3, 2., 0.]], fruit_output_layer.eval(sess))
+
   def testWarmStartVarWithVocabCurrentVarPartitioned(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -269,11 +321,43 @@
         self.assertTrue(
             isinstance(fruit_weights, variables.PartitionedVariable))
         fruit_weights_vars = fruit_weights._get_variable_list()
-        self.assertAllEqual([[2.], [1.5], [1.]],
+        self.assertAllClose([[2.], [1.5], [1.]],
                             fruit_weights_vars[0].eval(sess))
-        self.assertAllEqual([[0.5], [0.], [0.]],
+        self.assertAllClose([[0.5], [0.], [0.]],
                             fruit_weights_vars[1].eval(sess))
 
+  def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            shape=[4, 3],
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertTrue(
+            isinstance(fruit_output_layer, variables.PartitionedVariable))
+        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+                            fruit_output_layer_vars[0].eval(sess))
+        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+                            fruit_output_layer_vars[1].eval(sess))
+
   def testWarmStartVarWithVocabBothVarsPartitioned(self):
     prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"],
                                         "old_vocab")
@@ -301,11 +385,45 @@
         self.assertTrue(
             isinstance(fruit_weights, variables.PartitionedVariable))
         fruit_weights_vars = fruit_weights._get_variable_list()
-        self.assertAllEqual([[2.], [1.5], [1.]],
+        self.assertAllClose([[2.], [1.5], [1.]],
                             fruit_weights_vars[0].eval(sess))
-        self.assertAllEqual([[0.5], [0.], [0.]],
+        self.assertAllClose([[0.5], [0.], [0.]],
                             fruit_weights_vars[1].eval(sess))
 
+  def testWarmStartVarWithColumnVocabBothVarsPartitioned(self):
+    prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab")
+    self._create_prev_run_var(
+        "fruit_output_layer",
+        shape=[4, 2],
+        initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]],
+        partitioner=lambda shape, dtype: [2, 1])
+
+    # New vocab with elements in reverse order and one new element.
+    new_vocab_path = self._write_vocab(["orange", "apple", "banana"],
+                                       "new_vocab")
+    # New session and new graph.
+    with ops.Graph().as_default() as g:
+      with self.test_session(graph=g) as sess:
+        fruit_output_layer = variable_scope.get_variable(
+            "fruit_output_layer",
+            shape=[4, 3],
+            initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
+                         [0., 0., 0.]],
+            partitioner=lambda shape, dtype: [2, 1])
+        ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path,
+                                           current_vocab_size=3,
+                                           prev_ckpt=self.get_temp_dir(),
+                                           prev_vocab_path=prev_vocab_path,
+                                           axis=1)
+        sess.run(variables.global_variables_initializer())
+        self.assertTrue(
+            isinstance(fruit_output_layer, variables.PartitionedVariable))
+        fruit_output_layer_vars = fruit_output_layer._get_variable_list()
+        self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]],
+                            fruit_output_layer_vars[0].eval(sess))
+        self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]],
+                            fruit_output_layer_vars[1].eval(sess))
+
   def testWarmStart_ListOfVariables(self):
     # Save checkpoint from which to warm-start.
     _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1],
@@ -1015,7 +1133,7 @@
 
     # Unused variable names raises ValueError.
     with ops.Graph().as_default():
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x = variable_scope.get_variable(
             "x",
             shape=[4, 1],
diff --git a/tensorflow/python/util/memory.py b/tensorflow/python/util/memory.py
new file mode 100644
index 0000000..e78f6d5
--- /dev/null
+++ b/tensorflow/python/util/memory.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions related to Python memory management."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# TODO(b/115366440): Delete this function when a custom OrderedDict is added
+def dismantle_ordered_dict(ordered_dict):
+  """Remove reference cycle in OrderedDict `ordered_dict`.
+
+  Helpful for making sure the garbage collector doesn't need to run after
+  using an OrderedDict.
+
+  Args:
+    ordered_dict: An `OrderedDict` object to destroy. This object is unusable
+      after this function runs.
+  """
+  # OrderedDict makes a simple reference loop
+  # and hides it in an __attribute in some Python versions. We don't need to
+  # throw an error if we can't find it, but if we do find it we can break the
+  # loop to avoid creating work for the garbage collector.
+  problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None)  # pylint: disable=protected-access
+  if problematic_cycle:
+    try:
+      del problematic_cycle[0][:]
+    except TypeError:
+      # This is probably not one of the problematic Python versions. Continue
+      # with the rest of our cleanup.
+      pass
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2369eb6..ef50313 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -461,7 +461,7 @@
         inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
     }
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       output_np = sess.run(output, feed_dict=feed_dict)
     self.assertAllClose(output_np[0],
                         feed_dict[inp_a][0] + feed_dict[inp_b][0])
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 778121e..967c872 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -325,6 +325,11 @@
   return _inspect.isfunction(tf_decorator.unwrap(object)[1])
 
 
+def isgenerator(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.isgenerator."""
+  return _inspect.isgenerator(tf_decorator.unwrap(object)[1])
+
+
 def ismethod(object):  # pylint: disable=redefined-builtin
   """TFDecorator-aware replacement for inspect.ismethod."""
   return _inspect.ismethod(tf_decorator.unwrap(object)[1])
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index 16fa1f5..fedbe1d 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -106,7 +106,7 @@
     def return_const(value):
       return constant_op.constant(value, name='blah3')
     with reroute_error() as (error, _):
-      with self.test_session():
+      with self.cached_session():
         return_const(0.0)
         # Creating another op and executing it does not mark the
         # unused op as being "used".
@@ -124,7 +124,8 @@
     @tf_should_use.should_use_result
     def return_const(value):
       return constant_op.constant(value, name='blah3')
-    with self.test_session():
+
+    with self.cached_session():
       return_const(0.0).mark_used()
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6d336ac..104a615 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -104,9 +104,36 @@
 %unignore tensorflow::swig::Flatten;
 %noexception tensorflow::swig::Flatten;
 
+%feature("docstring") tensorflow::swig::IsSequenceForData
+"""Returns True if `seq` is a Sequence or dict (except strings/lists).
+
+NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
+which *does* treat a Python list as a sequence. For ergonomic
+reasons, `tf.data` users would prefer to treat lists as
+implicit `tf.Tensor` objects, and dicts as (nested) sequences.
+
+Args:
+  seq: an input sequence.
+
+Returns:
+  True if the sequence is not a string or list and is a
+  collections.Sequence.
+"""
 %unignore tensorflow::swig::IsSequenceForData;
 %noexception tensorflow::swig::IsSequenceForData;
 
+%feature("docstring") tensorflow::swig::FlattenForData
+"""Returns a flat sequence from a given nested structure.
+
+If `nest` is not a sequence, this returns a single-element list: `[nest]`.
+
+Args:
+  nest: an arbitrarily nested structure or a scalar object.
+    Note, numpy arrays are considered scalars.
+
+Returns:
+  A Python list, the flattened version of the input.
+"""
 %unignore tensorflow::swig::FlattenForData;
 %noexception tensorflow::swig::FlattenForData;
 
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 7f851e3..f25ed70 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -41,6 +41,7 @@
 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
 
 #include <complex>
+#include <vector>
 
 #include "tensorflow/stream_executor/host_or_device_scalar.h"
 #include "tensorflow/stream_executor/lib/array_slice.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 207f22c9..3c533c7 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3275,6 +3275,26 @@
         "This configuration potentially produces incorrect results.");
   }());
 
+  // Zero out the result buffer for strided conv backward filter for NHWC
+  // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer
+  // is not zeroed.
+  //
+  // The wrong result caused by this bug is flaky; it can take up to 20 runs
+  // to produce a mismatch.
+  //
+  // TODO(timshen): add a nvbugs link.
+  if (CUDNN_VERSION >= 7100 &&
+      algorithm_config.algorithm().algo_id() ==
+          CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
+      cudnn_type == CUDNN_DATA_HALF &&
+      input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+      filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
+      output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+      (convolution_descriptor.vertical_filter_stride() > 1 ||
+       convolution_descriptor.horizontal_filter_stride() > 1)) {
+    stream->ThenMemZero(backward_filter_data, backward_filter_data->size());
+  }
+
   RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
       cudnn.handle(),
       /*alpha=*/alpha,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e..10bf006 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@
 #include <atomic>
 #include <utility>
 
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/stream_executor/blas.h"
 #include "tensorflow/stream_executor/fft.h"
 #include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@
   CheckPlatformKindIsValid(platform_kind);
 }
 
+// Get per-device memory limit in bytes. Returns 0 if
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+  int64 value;
+  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+                                              0, &value));
+  return value * (1ll << 20);
+}
+
 StreamExecutor::StreamExecutor(
     const Platform *platform,
     std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@
       background_threads_(new port::ThreadPool(
           port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
       live_stream_count_(0),
-      tracing_enabled_(false) {
+      tracing_enabled_(false),
+      mem_alloc_bytes_(0),
+      memory_limit_bytes_(GetMemoryLimitBytes()) {
   if (port::Lowercase(platform_->Name()) == "cuda") {
     platform_kind_ = PlatformKind::kCuda;
   } else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@
 }
 
 void *StreamExecutor::Allocate(uint64 size) {
+  if (memory_limit_bytes_ > 0 &&
+      mem_alloc_bytes_ + size > memory_limit_bytes_) {
+    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+                 << device_ordinal_
+                 << " within provided limit. [used=" << mem_alloc_bytes_
+                 << ", limit=" << memory_limit_bytes_ << "]";
+    return nullptr;
+  }
   void *buf = implementation_->Allocate(size);
   VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
           << buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@
     mutex_lock lock(mu_);
     mem_allocs_[opaque] = AllocRecord{
         bytes, ""};
+    mem_alloc_bytes_ += bytes;
   }
 }
 
@@ -789,6 +810,7 @@
       LOG(ERROR) << "Deallocating unknown pointer: "
                  << port::Printf("0x%p", opaque);
     } else {
+      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
       mem_allocs_.erase(opaque);
     }
   }
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 437f298..d04025b 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -699,6 +699,13 @@
   // The set of TraceListeners registered for this StreamExecutor.
   std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
 
+  // Allocated memory in bytes.
+  int64 mem_alloc_bytes_;
+
+  // Memory limit in bytes. Value less or equal to 0 indicates there is no
+  // limit.
+  int64 memory_limit_bytes_;
+
   SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
 };
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd..15e0ab7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index eb41dee..9f6dcd8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@
       type: TYPE_STRING
     }
     field {
-      name: "client_handles_error_formatting"
-      number: 2
-      label: LABEL_OPTIONAL
-      type: TYPE_BOOL
-    }
-    field {
       name: "executor_type"
       number: 3
       label: LABEL_OPTIONAL
       type: TYPE_STRING
     }
+    reserved_range {
+      start: 2
+      end: 3
+    }
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index e565b90..f3a5151 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@
         type: TYPE_STRING
       }
       field {
-        name: "client_handles_error_formatting"
-        number: 2
-        label: LABEL_OPTIONAL
-        type: TYPE_BOOL
-      }
-      field {
         name: "executor_type"
         number: 3
         label: LABEL_OPTIONAL
         type: TYPE_STRING
       }
+      reserved_range {
+        start: 2
+        end: 3
+      }
     }
   }
 }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
index cbf6554..2f4257a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
   }
   member_method {
     name: "gradient"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279..39ff336 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 834f095..8774542 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4..6dd4636 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095..35b7105 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 587829a..8ae370a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94..b6942cb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
   member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
index 24a58fb..f06e798 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@
   }
   member_method {
     name: "input_layer"
-    argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+    argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index d843194..0869de0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -159,7 +159,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -219,7 +219,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index b8e9bac..20f39fa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -164,7 +164,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -228,7 +228,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 472b981..4011719 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -159,7 +159,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -219,7 +219,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 937516e..8a12ac1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -164,7 +164,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -228,7 +228,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000..e7e7d28
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "is_running"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "start"
+    argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+  }
+  member_method {
+    name: "stop"
+    argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index 4d7a151..81b91d2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "OrderedEnqueuer"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "Progbar"
     mtype: "<type \'type\'>"
   }
@@ -45,6 +49,10 @@
     argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
   }
   member_method {
+    name: "get_source_inputs"
+    argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
     name: "multi_gpu_model"
     argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
   }
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716..614ba42 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1..39b946b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
   member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd..15e0ab7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
index eb41dee..9f6dcd8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt
@@ -9,16 +9,14 @@
       type: TYPE_STRING
     }
     field {
-      name: "client_handles_error_formatting"
-      number: 2
-      label: LABEL_OPTIONAL
-      type: TYPE_BOOL
-    }
-    field {
       name: "executor_type"
       number: 3
       label: LABEL_OPTIONAL
       type: TYPE_STRING
     }
+    reserved_range {
+      start: 2
+      end: 3
+    }
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
index e565b90..f3a5151 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt
@@ -132,17 +132,15 @@
         type: TYPE_STRING
       }
       field {
-        name: "client_handles_error_formatting"
-        number: 2
-        label: LABEL_OPTIONAL
-        type: TYPE_BOOL
-      }
-      field {
         name: "executor_type"
         number: 3
         label: LABEL_OPTIONAL
         type: TYPE_STRING
       }
+      reserved_range {
+        start: 2
+        end: 3
+      }
     }
   }
 }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt
deleted file mode 100644
index 260c796..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.FixedLengthRecordReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.FixedLengthRecordReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
index cbf6554..2f4257a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+    argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
   }
   member_method {
     name: "gradient"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt
deleted file mode 100644
index 2eda320..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.IdentityReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.IdentityReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt
deleted file mode 100644
index f9b7e9b..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.LMDBReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.LMDBReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt
deleted file mode 100644
index f6a3ce7..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt
+++ /dev/null
@@ -1,45 +0,0 @@
-path: "tensorflow.ReaderBase"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'reader_ref\', \'supports_serialize\'], varargs=None, keywords=None, defaults=[\'False\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279..39ff336 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
   }
   member_method {
     name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt
deleted file mode 100644
index cdf7937..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.TFRecordReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.TFRecordReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt
deleted file mode 100644
index e9779f0..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.TextLineReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.TextLineReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'skip_header_lines\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt
deleted file mode 100644
index 4ac7598..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.WholeFileReader"
-tf_class {
-  is_instance: "<class \'tensorflow.python.ops.io_ops.WholeFileReader\'>"
-  is_instance: "<class \'tensorflow.python.ops.io_ops.ReaderBase\'>"
-  is_instance: "<type \'object\'>"
-  member {
-    name: "reader_ref"
-    mtype: "<type \'property\'>"
-  }
-  member {
-    name: "supports_serialize"
-    mtype: "<type \'property\'>"
-  }
-  member_method {
-    name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_records_produced"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "num_work_units_completed"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read"
-    argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "read_up_to"
-    argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "reset"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "restore_state"
-    argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "serialize_state"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 834f095..8774542 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4..6dd4636 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095..35b7105 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 587829a..8ae370a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@
   }
   member_method {
     name: "interleave"
-    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+    argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
   }
   member_method {
     name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
index 5301b94..b6942cb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt
@@ -4,6 +4,10 @@
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
   member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
index 24a58fb..f06e798 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@
   }
   member_method {
     name: "input_layer"
-    argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+    argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index d843194..0869de0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -159,7 +159,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -219,7 +219,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index b8e9bac..20f39fa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -164,7 +164,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -228,7 +228,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 472b981..4011719 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -159,7 +159,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -219,7 +219,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 937516e..8a12ac1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@
   }
   member_method {
     name: "evaluate"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "evaluate_generator"
@@ -164,7 +164,7 @@
   }
   member_method {
     name: "fit"
-    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "fit_generator"
@@ -228,7 +228,7 @@
   }
   member_method {
     name: "predict"
-    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
   }
   member_method {
     name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
new file mode 100644
index 0000000..e7e7d28
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.-ordered-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.OrderedEnqueuer"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.utils.data_utils.OrderedEnqueuer\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.data_utils.SequenceEnqueuer\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'sequence\', \'use_multiprocessing\', \'shuffle\'], varargs=None, keywords=None, defaults=[\'False\', \'False\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "is_running"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "start"
+    argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+  }
+  member_method {
+    name: "stop"
+    argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index 4d7a151..81b91d2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -13,6 +13,10 @@
     mtype: "<type \'type\'>"
   }
   member {
+    name: "OrderedEnqueuer"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "Progbar"
     mtype: "<type \'type\'>"
   }
@@ -45,6 +49,10 @@
     argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
   }
   member_method {
+    name: "get_source_inputs"
+    argspec: "args=[\'tensor\', \'layer\', \'node_index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
     name: "multi_gpu_model"
     argspec: "args=[\'model\', \'gpus\', \'cpu_merge\', \'cpu_relocation\'], varargs=None, keywords=None, defaults=[\'True\', \'False\'], "
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 7d45ea2..9332e16 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -61,10 +61,6 @@
     mtype: "<type \'type\'>"
   }
   member {
-    name: "FixedLengthRecordReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "GIT_VERSION"
     mtype: "<type \'str\'>"
   }
@@ -109,10 +105,6 @@
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
   member {
-    name: "IdentityReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "IndexedSlices"
     mtype: "<type \'type\'>"
   }
@@ -121,10 +113,6 @@
     mtype: "<type \'type\'>"
   }
   member {
-    name: "LMDBReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "LogMessage"
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
@@ -177,10 +165,6 @@
     mtype: "<type \'type\'>"
   }
   member {
-    name: "ReaderBase"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "RegisterGradient"
     mtype: "<type \'type\'>"
   }
@@ -225,10 +209,6 @@
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
   member {
-    name: "TFRecordReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "Tensor"
     mtype: "<type \'type\'>"
   }
@@ -245,10 +225,6 @@
     mtype: "<type \'type\'>"
   }
   member {
-    name: "TextLineReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "VERSION"
     mtype: "<type \'str\'>"
   }
@@ -273,10 +249,6 @@
     mtype: "<class \'enum.EnumMeta\'>"
   }
   member {
-    name: "WholeFileReader"
-    mtype: "<type \'type\'>"
-  }
-  member {
     name: "app"
     mtype: "<type \'module\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716..614ba42 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
index 4ce7cb1..39b946b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
@@ -4,6 +4,10 @@
   is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
   is_instance: "<type \'tuple\'>"
   member {
+    name: "axis"
+    mtype: "<type \'property\'>"
+  }
+  member {
     name: "backup_initializer"
     mtype: "<type \'property\'>"
   }
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index c35e254..b21dabb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -249,14 +249,6 @@
     argspec: "args=[\'supervisor\', \'train_step_fn\', \'args\', \'kwargs\', \'master\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\'], "
   }
   member_method {
-    name: "batch"
-    argspec: "args=[\'tensors\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "batch_join"
-    argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "checkpoint_exists"
     argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None"
   }
@@ -317,10 +309,6 @@
     argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "input_producer"
-    argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "inverse_time_decay"
     argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -329,10 +317,6 @@
     argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
-    name: "limit_epochs"
-    argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
-  }
-  member_method {
     name: "linear_cosine_decay"
     argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], "
   }
@@ -353,22 +337,6 @@
     argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
-    name: "maybe_batch"
-    argspec: "args=[\'tensors\', \'keep_input\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "maybe_batch_join"
-    argspec: "args=[\'tensors_list\', \'keep_input\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "maybe_shuffle_batch"
-    argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "maybe_shuffle_batch_join"
-    argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "natural_exp_decay"
     argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -385,10 +353,6 @@
     argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
   }
   member_method {
-    name: "range_input_producer"
-    argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "remove_checkpoint"
     argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], "
   }
@@ -409,22 +373,6 @@
     argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
-    name: "shuffle_batch"
-    argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "shuffle_batch_join"
-    argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "slice_input_producer"
-    argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
-  }
-  member_method {
-    name: "string_input_producer"
-    argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], "
-  }
-  member_method {
     name: "summary_iterator"
     argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
   }
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8764409..4efa4a9 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -15,7 +15,10 @@
 
 py_test(
     name = "api_compatibility_test",
-    srcs = ["api_compatibility_test.py"],
+    srcs = [
+        "api_compatibility_test.py",
+        "//tensorflow:tf_python_api_gen_v2",
+    ],
     data = [
         "//tensorflow/tools/api/golden:api_golden_v1",
         "//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc..d06c7f2 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@
 import unittest
 
 import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
 
 from google.protobuf import message
 from google.protobuf import text_format
@@ -173,7 +174,7 @@
         verbose_diff_message = diff_message
       else:
         # Do not truncate diff
-        self.maxDiffs = None  # pylint: disable=invalid-name
+        self.maxDiff = None  # pylint: disable=invalid-name
         # Now we can run an actual proto diff.
         try:
           self.assertProtoEquals(expected_dict[key], actual_dict[key])
@@ -232,14 +233,14 @@
       return
     visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
     visitor.do_not_descend_map['tf'].append('contrib')
-    traverse.traverse(tf.compat.v1, visitor)
+    traverse.traverse(tf_v2.compat.v1, visitor)
 
   def testNoSubclassOfMessageV2(self):
     if not hasattr(tf.compat, 'v2'):
       return
     visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
     visitor.do_not_descend_map['tf'].append('contrib')
-    traverse.traverse(tf.compat.v2, visitor)
+    traverse.traverse(tf_v2, visitor)
 
   def _checkBackwardsCompatibility(
       self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@
       sys.version_info.major == 2,
       'API compabitility test goldens are generated using python2.')
   def testAPIBackwardsCompatibilityV1(self):
-    if not hasattr(tf.compat, 'v1'):
-      return
     api_version = 1
     golden_file_pattern = os.path.join(
         resource_loader.get_root_dir_with_all_resources(),
         _KeyToFilePath('*', api_version))
     self._checkBackwardsCompatibility(
-        tf.compat.v1, golden_file_pattern, api_version)
+        tf_v2.compat.v1, golden_file_pattern, api_version)
 
   @unittest.skipUnless(
       sys.version_info.major == 2,
       'API compabitility test goldens are generated using python2.')
   def testAPIBackwardsCompatibilityV2(self):
-    if not hasattr(tf.compat, 'v2'):
-      return
     api_version = 2
     golden_file_pattern = os.path.join(
         resource_loader.get_root_dir_with_all_resources(),
         _KeyToFilePath('*', api_version))
     self._checkBackwardsCompatibility(
-        tf.compat.v2, golden_file_pattern, api_version)
+        tf_v2, golden_file_pattern, api_version,
+        additional_private_map={'tf.compat': ['v1']})
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu
index f05c7a4..a4cad4b 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu
@@ -30,3 +30,4 @@
 
 # Configure the build for our CUDA configuration.
 ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
new file mode 100644
index 0000000..a30858d
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -0,0 +1,83 @@
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 \
+#       --tag "gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04" .
+# $ docker push gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04
+#
+# TODO(klimek): Include clang in this image so we can also target clang
+# builds.
+
+FROM ubuntu:14.04
+LABEL maintainer="Manuel Klimek <klimek@google.com>"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates apt-transport-https gnupg-curl && \
+    rm -rf /var/lib/apt/lists/* && \
+    NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \
+    NVIDIA_GPGKEY_FPR=ae09fe4bbd223a84b2ccfce3f60f4b3d7fa2af80 && \
+    apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/7fa2af80.pub && \
+    apt-key adv --export --no-emit-version -a $NVIDIA_GPGKEY_FPR | tail -n +2 > cudasign.pub && \
+    echo "$NVIDIA_GPGKEY_SUM  cudasign.pub" | sha256sum -c --strict - && rm cudasign.pub && \
+    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
+    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list
+
+ENV CUDA_VERSION 9.0.176
+ENV CUDA_PKG_VERSION 9-0=$CUDA_VERSION-1
+ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
+ENV NCCL_VERSION 2.2.13
+ENV CUDNN_VERSION 7.2.1.38
+
+# TODO(b/110903506): /usr/loca/cuda/lib64/stubs should not be needed in
+# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to used at runtime. The
+# correct way to pass the path to bfd-ld is to pass
+# -Wl,-rpath-link=/usr/local/cuda/lib64/stubs to all binaries transitively
+# depending on libcuda. Optimally, builds targeting cuda would do that
+# internally.
+ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64/stubs
+
+LABEL com.nvidia.volumes.needed="nvidia_driver"
+LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
+LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+        cuda-cudart-$CUDA_PKG_VERSION \
+        cuda-libraries-$CUDA_PKG_VERSION \
+        cuda-cublas-9-0=9.0.176.4-1 \
+        libnccl2=$NCCL_VERSION-1+cuda9.0 \
+        cuda-libraries-dev-$CUDA_PKG_VERSION \
+        cuda-nvml-dev-$CUDA_PKG_VERSION \
+        cuda-minimal-build-$CUDA_PKG_VERSION \
+        cuda-command-line-tools-$CUDA_PKG_VERSION \
+        cuda-core-9-0=9.0.176.3-1 \
+        cuda-cublas-dev-9-0=9.0.176.4-1 \
+        libnccl-dev=$NCCL_VERSION-1+cuda9.0 \
+        libcudnn7-dev=$CUDNN_VERSION-1+cuda9.0 \
+        libcudnn7=$CUDNN_VERSION-1+cuda9.0 && \
+    ln -s cuda-9.0 /usr/local/cuda && \
+    apt-mark hold libnccl2 && \
+    apt-mark hold libcudnn7 libcudnn7-dev && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
+    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
+
+# TODO(b/110903506): Provide a link to the SONAME of libcuda.so.
+# https://github.com/NVIDIA/nvidia-docker/issues/775
+RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
+# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
+# libnccl is resolved, delete this block.
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
+ && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+    add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_golang.sh
+
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index bbaf59c..4b762bf 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -76,7 +76,7 @@
 
 # Do not run tests with "no_pip" tag. If running GPU tests, also do not run
 # tests with no_pip_gpu tag.
-PIP_TEST_FILTER_TAG="-no_pip,-no_oss"
+PIP_TEST_FILTER_TAG="-no_pip,-no_oss,-benchmark-test"
 if [[ ${IS_OSS_SERIAL} == "1" ]]; then
   PIP_TEST_FILTER_TAG="$(echo "${PIP_TEST_FILTER_TAG}" | sed s/-no_oss//)"
   PIP_TEST_FILTER_TAG="${PIP_TEST_FILTER_TAG},oss_serial"
@@ -85,7 +85,7 @@
 fi
 
 if [[ ${IS_GPU} == "1" ]]; then
-  PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
+  PIP_TEST_FILTER_TAG="-no_gpu,-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
 fi
 if [[ ${IS_MAC} == "1" ]]; then
   PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 1d7d9df..cc09784c 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -86,7 +86,7 @@
 #                     When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX
 #                     options, as this will replace the two.
 #   TF_SKIP_CONTRIB_TESTS:
-#                     If set to any non-empty or non-0 value, will skipp running
+#                     If set to any non-empty or non-0 value, will skip running
 #                     contrib tests.
 #   TF_NIGHTLY:
 #                     If this run is being used to build the tf_nightly pip
@@ -127,11 +127,19 @@
 
 DO_DOCKER=1
 
-BAZEL_CMD="bazel test"
-BAZEL_BUILD_ONLY_CMD="bazel build"
-BAZEL_CLEAN_CMD="bazel clean"
 
-DEFAULT_BAZEL_CONFIGS=""
+# Helpful flags:
+# --test_summary=detailed: Tell us more about which targets are being built
+# --keep_going: Don't stop at the first failure; tell us all the failures
+# --build_tests_only: Don't build targets depended on by tests if the test is
+#                     disabled. Also saves some compilation time. Otherwise,
+#                     tries to build everything.
+BAZEL_TEST_FLAGS="--test_summary=detailed --build_tests_only --keep_going"
+BAZEL_BUILD_FLAGS="--keep_going"
+
+BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS}"
+BAZEL_BUILD_ONLY_CMD="bazel build ${BAZEL_BUILD_FLAGS}"
+BAZEL_CLEAN_CMD="bazel clean"
 
 PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh"
 PIP_TEST_TUTORIALS_FLAG="--test_tutorials"
@@ -148,9 +156,7 @@
 BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..."
 
 if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then
-  BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..."
-else
-  BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..."
+  BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/..."
 fi
 
 TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data"
@@ -389,7 +395,7 @@
 EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false"
 
 if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] &&
-   [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then 
+   [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
   BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD}
 fi
 
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 75da9bb..03a2a07 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -16,29 +16,25 @@
 #
 #
 # A script to run multiple GPU tests in parallel controlled with an environment
-# variable. This script will assume that when it runs, one of the locks are
-# already released. So the program calling this script is expected to make sure
-# that only $TF_GPU_COUNT processes are running at any gien time.
+# variable.
 #
 # Required environment variables:
-#     TF_GPU_COUNT = Number of GPUs available. This HAS TO BE IN SYNC with the
-#                    value of --local_test_jobs flag for bazel.
+#     TF_GPU_COUNT = Number of GPUs available.
 
-BASH_VER_MAJOR=$(echo ${BASH_VERSION} | cut -d '.' -f 1)
-BASH_VER_MINOR=$(echo ${BASH_VERSION} | cut -d '.' -f 2)
+TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-4}
+# We want to allow running one of the following configs:
+#  - 4 tests per GPU on k80
+#  - 8 tests per GPU on p100
+# p100 has minimum 12G memory. Therefore, we should limit each test to 1.5G.
+# To leave some room in case we want to run more tests in parallel in the
+# future and to use a rounder number, we set it to 1G.
+export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024
 
-if [[ ${BASH_VER_MAJOR} -lt 4 ]]; then
-  echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
-  exit 1
-elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then
-  echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
-  exit 1
-fi
-
-function is_absolute {
-  [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
-}
-
+# *******************************************************************
+#         This section of the script is needed to
+#         make things work on windows under msys.
+# *******************************************************************
 RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST"
 function rlocation() {
   if is_absolute "$1" ; then
@@ -55,29 +51,32 @@
 
 TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})"
 shift
+# *******************************************************************
 
-# Make sure /var/lock exists, this may not be true under MSYS
 mkdir -p /var/lock
-
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-
-for i in `seq 0 $((TF_GPU_COUNT-1))`; do
-  exec {lock_fd}>/var/lock/gpulock$i || exit 1
-  if flock -n "$lock_fd";
-  then
-    (
-      # This export only works within the brackets, so it is isolated to one
-      # single command.
-      export CUDA_VISIBLE_DEVICES=$i
-      echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
-      "$TEST_BINARY" $@
-    )
-    return_code=$?
-    flock -u "$lock_fd"
-    exit $return_code
-  fi
+# Try to acquire any of the TF_GPU_COUNT * TF_TESTS_PER_GPU
+# slots to run a test at.
+#
+# Prefer to allocate 1 test per GPU over 4 tests on 1 GPU.
+# So, we iterate over TF_TESTS_PER_GPU first.
+for j in `seq 0 $((TF_TESTS_PER_GPU-1))`; do
+  for i in `seq 0 $((TF_GPU_COUNT-1))`; do
+    exec {lock_fd}>/var/lock/gpulock${i}_${j} || exit 1
+    if flock -n "$lock_fd";
+    then
+      (
+        # This export only works within the brackets, so it is isolated to one
+        # single command.
+        export CUDA_VISIBLE_DEVICES=$i
+        echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
+        "$TEST_BINARY" $@
+      )
+      return_code=$?
+      flock -u "$lock_fd"
+      exit $return_code
+    fi
+  done
 done
 
 echo "Cannot find a free GPU to run the test $* on, exiting with failure..."
 exit 1
-
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 9640810..179fc42 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -67,6 +67,12 @@
     zip \
     zlib1g-dev
 
+apt-get update && \
+  apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
+  apt-get update && \
+  apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
+  apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
+
 # populate the database
 updatedb
 
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index af478ed..a9ae715 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -119,6 +119,8 @@
 pip3 install keras_applications==1.0.5 --no-deps
 pip2 install keras_preprocessing==1.0.3 --no-deps
 pip3 install keras_preprocessing==1.0.3 --no-deps
+pip2 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==2.8.0
 
 # Install last working version of setuptools.
 pip2 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 93ea0c3..37e6b51 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -87,6 +87,7 @@
 # Keras
 pip3.5 install keras_applications==1.0.5
 pip3.5 install keras_preprocessing==1.0.3
+pip3.5 install --upgrade h5py==2.8.0
 
 # Install last working version of setuptools.
 pip3.5 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 7a9eef7..7520ff7 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -99,6 +99,7 @@
 
 # Install last working version of setuptools.
 pip3 install --upgrade setuptools==39.1.0
+pip3 install --upgrade h5py==2.8.0
 
 # Keras
 pip3 install keras_applications==1.0.5
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
index 2a9f295..7be5f45 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
@@ -33,7 +33,7 @@
 # Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution
 # in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads
 # caused by executing multiple tests concurrently.
-bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
+bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \
     --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
     --config=mkl --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \
     //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index f958b3c..60c974c 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -52,6 +52,7 @@
   -e "PYTHON_BIN_PATH=/usr/bin/python" \
   -e "TF_NEED_HDFS=0" \
   -e "TF_NEED_CUDA=${TF_NEED_CUDA}" \
+  -e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \
   -e "TF_NEED_OPENCL_SYCL=0" \
   "${DOCKER_IMAGE}" \
   "/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh"
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 333a89d..c18f0d6 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -53,7 +53,7 @@
 
 # Setting default values to CUDA related environment variables
 export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0}
-export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0}
+export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7}
 export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7}
 export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
 export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 216aa41..7e66ad8 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -65,6 +65,7 @@
     'tf.fft': 'tf.spectral.fft',
     'tf.floor': 'tf.math.floor',
     'tf.gather_nd': 'tf.manip.gather_nd',
+    'tf.GraphKeys.VARIABLES': 'tf.GraphKeys.GLOBAL_VARIABLES',
     'tf.greater': 'tf.math.greater',
     'tf.greater_equal': 'tf.math.greater_equal',
     'tf.ifft': 'tf.spectral.ifft',
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 01f37d8..35a74c9 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -35,7 +35,7 @@
   """
 
   def testArgRenames(self):
-    with self.test_session():
+    with self.cached_session():
 
       a = [[1., 2., 3.], [4., 5., 6.]]
       b = [[True, False, False], [False, True, True]]
@@ -98,7 +98,7 @@
           [[[1, 2]], [[3, 4]]])
 
   def testArgMinMax(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(
           tf.argmin([[1, 2, 3], [4, 1, 0]], dimension=1).eval(),
           [0, 2])
@@ -113,7 +113,7 @@
           [1, 0, 0])
 
   def testExpandAndSqueeze(self):
-    with self.test_session():
+    with self.cached_session():
 
       # TODO(aselle): sparse_split, sparse_reduce_sum,
       #  sparse_reduce_sum_sparse, reduce_join
@@ -140,7 +140,7 @@
           a)
 
   def testArithmeticRenames(self):
-    with self.test_session() as s:
+    with self.cached_session() as s:
       stuff = tf.split(1, 2, [[1, 2, 3, 4], [4, 5, 6, 7]])
       vals = s.run(stuff)
       self.assertAllEqual(vals,
@@ -164,7 +164,7 @@
       # ]
 
   def testBatchAndSvd(self):
-    with self.test_session():
+    with self.cached_session():
       mat = [[1., 2.], [2., 3.]]
       batched_mat = tf.expand_dims(mat, [0])
       result = tf.matmul(mat, mat).eval()
@@ -176,7 +176,7 @@
 
   def testCrossEntropy(self):
     # TODO(aselle): Test sparse_softmax_...
-    with self.test_session():
+    with self.cached_session():
       labels = [.8, .5, .2, .1]
       logits = [.9, .1, .3, .1]
       self.assertAllEqual(
@@ -191,7 +191,7 @@
               labels=labels, logits=logits).eval())
 
   def testVariables(self):
-    with self.test_session() as s:
+    with self.cached_session() as s:
 
       # make some variables
       _ = [tf.Variable([1, 2, 3], dtype=tf.float32),
@@ -201,7 +201,7 @@
       _ = [v.name for v in tf.local_variables()]
 
   def testSummaries(self):
-    with self.test_session() as s:
+    with self.cached_session() as s:
       var = tf.Variable([1, 2, 3], dtype=tf.float32)
       s.run(tf.initialize_all_variables())
       x, y = np.meshgrid(np.linspace(-10, 10, 256), np.linspace(-10, 10, 256))
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
index a49035a1..e5ca8d3 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
@@ -26,7 +26,7 @@
   """Test various APIs that have been changed in 2.0."""
 
   def testRenames(self):
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(1.04719755, tf.acos(0.5).eval())
       self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
 
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 9702430..38216ce 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 import argparse
+import functools
 
 from tensorflow.tools.compatibility import ast_edits
 from tensorflow.tools.compatibility import renames_v2
@@ -45,6 +46,29 @@
 
     # Specially handled functions.
     self.function_handle = {}
+    for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+                  "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+                  "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+                  "tf.train.cosine_decay_restarts",
+                  "tf.train.linear_cosine_decay",
+                  "tf.train.noisy_linear_cosine_decay"]:
+      self.function_handle[decay] = functools.partial(
+          self._learning_rate_decay_handler, decay_name=decay)
+
+  @staticmethod
+  def _learning_rate_decay_handler(file_edit_recorder, node, decay_name):
+    comment = ("ERROR: %s has been changed to return a callable instead of a "
+               "tensor when graph building, but its functionality remains "
+               "unchanged during eager execution (returns a callable like "
+               "before). The converter cannot detect and fix this reliably, so "
+               "you need to inspect this usage manually.\n") % decay_name
+    file_edit_recorder.add(
+        comment,
+        node.lineno,
+        node.col_offset,
+        decay_name,
+        decay_name,
+        error="%s requires manual check." % decay_name)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 57ac04d..3886c1e 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -63,6 +63,19 @@
     _, unused_report, unused_errors, new_text = self._upgrade(text)
     self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n")
 
+  def testLearningRateDecay(self):
+    for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+                  "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+                  "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+                  "tf.train.cosine_decay_restarts",
+                  "tf.train.linear_cosine_decay",
+                  "tf.train.noisy_linear_cosine_decay"]:
+
+      text = "%s(a, b)\n" % decay
+      _, unused_report, errors, new_text = self._upgrade(text)
+      self.assertEqual(text, new_text)
+      self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
+
 
 class TestUpgradeFiles(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
index c484c16..5996573 100644
--- a/tensorflow/tools/dockerfiles/README.md
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -2,8 +2,8 @@
 
 This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES
 MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from
-the files in `partials/` and the rules in `spec.yml`. See [the Maintaining
-section](#maintaining) for more information.
+the files in `partials/` and the rules in `spec.yml`. See [the Contributing
+section](#contributing) for more information.
 
 ## Building
 
@@ -34,13 +34,13 @@
 # User permissions (-u) are required if you use (-v).
 
 # CPU-based images
-$ docker run -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
 
 # GPU-based images (set up nvidia-docker2 first)
-$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
 
 # Images with Jupyter run on port 8888, and needs a volume for notebooks
-$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(PWD):/notebooks -it tf
+$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf
 ```
 
 These images do not come with the TensorFlow source code -- but the development
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 483921f..1cd9cb7 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -36,23 +36,6 @@
 from tensorflow.tools.docs import py_guide_parser
 
 
-def _is_free_function(py_object, full_name, index):
-  """Check if input is a free function (and not a class- or static method)."""
-  if not tf_inspect.isfunction(py_object):
-    return False
-
-  # Static methods are functions to tf_inspect (in 2.7), so check if the parent
-  # is a class. If there is no parent, it's not a function.
-  if '.' not in full_name:
-    return False
-
-  parent_name = full_name.rsplit('.', 1)[0]
-  if tf_inspect.isclass(index[parent_name]):
-    return False
-
-  return True
-
-
 def write_docs(output_dir,
                parser_config,
                yaml_toc,
@@ -109,7 +92,7 @@
 
     # Methods and some routines are documented only as part of their class.
     if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or
-            _is_free_function(py_object, full_name, parser_config.index)):
+            parser.is_free_function(py_object, full_name, parser_config.index)):
       continue
 
     sitepath = os.path.join('api_docs/python',
@@ -548,6 +531,13 @@
         help='The path from the site-root to api_docs'
              'directory for this project')
 
+    self.argument_parser.add_argument(
+        '--api_cache_out_path',
+        type=str,
+        default=None,
+        help='Path to store a json-serialized api-index, so links can be '
+        'inserted into docs without rebuilding the api_docs')
+
   def add_output_dir_argument(self):
     self.argument_parser.add_argument(
         '--output_dir',
@@ -648,6 +638,9 @@
     visitor = self.run_extraction()
     reference_resolver = self.make_reference_resolver(visitor, doc_index)
 
+    if getattr(flags, 'api_cache_out_path', None):
+      reference_resolver.to_json_file(flags.api_cache_out_path)
+
     # Build the guide_index for the api_docs back links.
     root_title = getattr(flags, 'root_title', 'TensorFlow')
     guide_index = _build_guide_index(
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 997afc6..83b4bf8 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -35,6 +35,28 @@
 from tensorflow.tools.docs import doc_controls
 
 
+def is_free_function(py_object, full_name, index):
+  """Check if input is a free function (and not a class- or static method).
+
+  Args:
+    py_object: The object in question.
+    full_name: The full name of the object, like `tf.module.symbol`.
+    index: The {full_name:py_object} dictionary for the public API.
+
+  Returns:
+    True if the object is a stand-alone function, and not part of a class
+    definition.
+  """
+  if not tf_inspect.isfunction(py_object):
+    return False
+
+  parent_name = full_name.rsplit('.', 1)[0]
+  if tf_inspect.isclass(index[parent_name]):
+    return False
+
+  return True
+
+
 # A regular expression capturing a python identifier.
 IDENTIFIER_RE = r'[a-zA-Z_]\w*'
 
@@ -74,7 +96,7 @@
     return self._errors == other._errors  # pylint: disable=protected-access
 
 
-def documentation_path(full_name):
+def documentation_path(full_name, is_fragment=False):
   """Returns the file path for the documentation for the given API symbol.
 
   Given the fully qualified name of a library symbol, compute the path to which
@@ -84,12 +106,22 @@
 
   Args:
     full_name: Fully qualified name of a library symbol.
-
+    is_fragment: If `False` produce a direct markdown link (`tf.a.b.c` -->
+      `tf/a/b/c.md`). If `True` produce fragment link, `tf.a.b.c` -->
+      `tf/a/b.md#c`
   Returns:
     The file path to which to write the documentation for `full_name`.
   """
-  dirs = full_name.split('.')
-  return os.path.join(*dirs) + '.md'
+  parts = full_name.split('.')
+  if is_fragment:
+    parts, fragment = parts[:-1], parts[-1]
+
+  result = os.path.join(*parts) + '.md'
+
+  if is_fragment:
+    result = result + '#' + fragment
+
+  return result
 
 
 def _get_raw_docstring(py_object):
@@ -136,8 +168,7 @@
       doc.
   """
 
-  def __init__(self, duplicate_of, doc_index, is_class, is_module,
-               py_module_names):
+  def __init__(self, duplicate_of, doc_index, is_fragment, py_module_names):
     """Initializes a Reference Resolver.
 
     Args:
@@ -145,15 +176,15 @@
         symbols.
       doc_index: A `dict` mapping symbol name strings to objects with `url`
         and `title` fields. Used to resolve @{$doc} references in docstrings.
-      is_class: A map from full names to bool for each symbol.
-      is_module: A map from full names to bool for each symbol.
+      is_fragment: A map from full names to bool for each symbol. If True the
+        object lives at a page fragment `tf.a.b.c` --> `tf/a/b#c`. If False
+        object has a page to itself: `tf.a.b.c` --> `tf/a/b/c`.
       py_module_names: A list of string names of Python modules.
     """
     self._duplicate_of = duplicate_of
     self._doc_index = doc_index
-    self._is_class = is_class
-    self._is_module = is_module
-    self._all_names = set(is_class.keys())
+    self._is_fragment = is_fragment
+    self._all_names = set(is_fragment.keys())
     self._py_module_names = py_module_names
 
     self.current_doc_full_name = None
@@ -180,21 +211,18 @@
     Returns:
       an instance of `ReferenceResolver` ()
     """
-    is_class = {
-        name: tf_inspect.isclass(visitor.index[name])
-        for name, obj in visitor.index.items()
-    }
+    is_fragment = {}
+    for name, obj in visitor.index.items():
+      has_page = (
+          tf_inspect.isclass(obj) or tf_inspect.ismodule(obj) or
+          is_free_function(obj, name, visitor.index))
 
-    is_module = {
-        name: tf_inspect.ismodule(visitor.index[name])
-        for name, obj in visitor.index.items()
-    }
+      is_fragment[name] = not has_page
 
     return cls(
         duplicate_of=visitor.duplicate_of,
         doc_index=doc_index,
-        is_class=is_class,
-        is_module=is_module,
+        is_fragment=is_fragment,
         **kwargs)
 
   @classmethod
@@ -210,6 +238,10 @@
     Args:
       filepath: The file path to write the json to.
     """
+    try:
+      os.makedirs(os.path.dirname(filepath))
+    except OSError:
+      pass
     json_dict = {}
     for key, value in self.__dict__.items():
       # Drop these two fields. `_doc_index` is not serializable. `_all_names` is
@@ -223,7 +255,7 @@
       json_dict[key.lstrip('_')] = value
 
     with open(filepath, 'w') as f:
-      json.dump(json_dict, f)
+      json.dump(json_dict, f, indent=2, sort_keys=True)
 
   def replace_references(self, string, relative_path_to_root):
     """Replace "@{symbol}" references with links to symbol's documentation page.
@@ -339,19 +371,7 @@
       raise TFDocsError(
           'Cannot make link to "%s": Not in index.' % master_name)
 
-    # If this is a member of a class, link to the class page with an anchor.
-    ref_path = None
-    if not (self._is_class[master_name] or self._is_module[master_name]):
-      idents = master_name.split('.')
-      if len(idents) > 1:
-        class_name = '.'.join(idents[:-1])
-        assert class_name in self._all_names
-        if self._is_class[class_name]:
-          ref_path = documentation_path(class_name) + '#%s' % idents[-1]
-
-    if not ref_path:
-      ref_path = documentation_path(master_name)
-
+    ref_path = documentation_path(master_name, self._is_fragment[master_name])
     return os.path.join(relative_path_to_root, ref_path)
 
   def _one_ref(self, match, relative_path_to_root):
@@ -947,6 +967,7 @@
     self._aliases = None
     self._doc = None
     self._guides = None
+    self._namedtuplefields = None
 
     self._bases = None
     self._properties = []
@@ -1030,6 +1051,17 @@
     self._guides = guides
 
   @property
+  def namedtuplefields(self):
+    return self._namedtuplefields
+
+  def set_namedtuplefields(self, py_class):
+    if issubclass(py_class, tuple):
+      if all(
+          hasattr(py_class, attr)
+          for attr in ('_asdict', '_fields', '_make', '_replace')):
+        self._namedtuplefields = py_class._fields
+
+  @property
   def bases(self):
     """Returns a list of `_LinkInfo` objects pointing to the class' parents."""
     return self._bases
@@ -1066,7 +1098,15 @@
   @property
   def properties(self):
     """Returns a list of `_PropertyInfo` describing the class' properties."""
-    return self._properties
+    props_dict = {prop.short_name: prop for prop in self._properties}
+    props = []
+    if self.namedtuplefields:
+      for field in self.namedtuplefields:
+        props.append(props_dict.pop(field))
+
+    props.extend(sorted(props_dict.values()))
+
+    return props
 
   def _add_property(self, short_name, full_name, obj, doc):
     """Adds a `_PropertyInfo` entry to the `properties` list.
@@ -1077,6 +1117,9 @@
       obj: The property object itself
       doc: The property's parsed docstring, a `_DocstringInfo`.
     """
+    # Hide useless namedtuple docstrings
+    if re.match('Alias for field number [0-9]+', doc.docstring):
+      doc = doc._replace(docstring='', brief='')
     property_info = _PropertyInfo(short_name, full_name, obj, doc)
     self._properties.append(property_info)
 
@@ -1156,6 +1199,7 @@
       py_class: The class object being documented
       parser_config: An instance of ParserConfig.
     """
+    self.set_namedtuplefields(py_class)
     doc_path = documentation_path(self.full_name)
     relative_path = os.path.relpath(
         path='.', start=os.path.dirname(doc_path) or '.')
@@ -1435,7 +1479,7 @@
     self.base_dir = base_dir
     self.defined_in_prefix = 'tensorflow/'
     self.code_url_prefix = (
-        'https://www.tensorflow.org/code/tensorflow/')  # pylint: disable=line-too-long
+        '/code/stable/tensorflow/')  # pylint: disable=line-too-long
 
   def py_name_to_object(self, full_name):
     """Return the Python object for a Python symbol name."""
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 9f6b185..8a41796 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import functools
 import os
 import sys
@@ -27,6 +28,12 @@
 from tensorflow.tools.docs import doc_controls
 from tensorflow.tools.docs import parser
 
+# The test needs a real module. `types.ModuleType()` doesn't work, as the result
+# is a `builtin` module. Using "parser" here is arbitrary. The tests don't
+# depend on the module contents. At this point in the process the public api
+# has already been extracted.
+test_module = parser
+
 
 def test_function(unused_arg, unused_kwarg='default'):
   """Docstring for test function."""
@@ -190,6 +197,50 @@
     # Make sure this file is contained as the definition location.
     self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
 
+  def test_namedtuple_field_order(self):
+    namedtupleclass = collections.namedtuple('namedtupleclass',
+                                             {'z', 'y', 'x', 'w', 'v', 'u'})
+
+    index = {
+        'namedtupleclass': namedtupleclass,
+        'namedtupleclass.u': namedtupleclass.u,
+        'namedtupleclass.v': namedtupleclass.v,
+        'namedtupleclass.w': namedtupleclass.w,
+        'namedtupleclass.x': namedtupleclass.x,
+        'namedtupleclass.y': namedtupleclass.y,
+        'namedtupleclass.z': namedtupleclass.z,
+    }
+
+    visitor = DummyVisitor(index=index, duplicate_of={})
+
+    reference_resolver = parser.ReferenceResolver.from_visitor(
+        visitor=visitor, doc_index={}, py_module_names=['tf'])
+
+    tree = {'namedtupleclass': {'u', 'v', 'w', 'x', 'y', 'z'}}
+    parser_config = parser.ParserConfig(
+        reference_resolver=reference_resolver,
+        duplicates={},
+        duplicate_of={},
+        tree=tree,
+        index=index,
+        reverse_index={},
+        guide_index={},
+        base_dir='/')
+
+    page_info = parser.docs_for_object(
+        full_name='namedtupleclass',
+        py_object=namedtupleclass,
+        parser_config=parser_config)
+
+    # Each namedtuple field has a docstring of the form:
+    #   'Alias for field number ##'. These props are returned sorted.
+
+    def sort_key(prop_info):
+      return int(prop_info.obj.__doc__.split(' ')[-1])
+
+    self.assertSequenceEqual(page_info.properties,
+                             sorted(page_info.properties, key=sort_key))
+
   def test_docs_for_class_should_skip(self):
 
     class Parent(object):
@@ -289,15 +340,16 @@
     self.assertEqual('my_method', page_info.methods[0].short_name)
 
   def test_docs_for_module(self):
-    # Get the current module.
-    module = sys.modules[__name__]
 
     index = {
-        'TestModule': module,
-        'TestModule.test_function': test_function,
+        'TestModule':
+            test_module,
+        'TestModule.test_function':
+            test_function,
         'TestModule.test_function_with_args_kwargs':
-        test_function_with_args_kwargs,
-        'TestModule.TestClass': TestClass,
+            test_function_with_args_kwargs,
+        'TestModule.TestClass':
+            TestClass,
     }
 
     visitor = DummyVisitor(index=index, duplicate_of={})
@@ -320,11 +372,13 @@
         base_dir='/')
 
     page_info = parser.docs_for_object(
-        full_name='TestModule', py_object=module, parser_config=parser_config)
+        full_name='TestModule',
+        py_object=test_module,
+        parser_config=parser_config)
 
     # Make sure the brief docstring is present
-    self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
-                     page_info.doc.brief)
+    self.assertEqual(
+        tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
 
     # Make sure that the members are there
     funcs = {f_info.obj for f_info in page_info.functions}
@@ -333,8 +387,9 @@
     classes = {cls_info.obj for cls_info in page_info.classes}
     self.assertEqual({TestClass}, classes)
 
-    # Make sure this file is contained as the definition location.
-    self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+    # Make sure the module's file is contained as the definition location.
+    self.assertEqual(
+        os.path.relpath(test_module.__file__, '/'), page_info.defined_in.path)
 
   def test_docs_for_function(self):
     index = {
@@ -450,6 +505,7 @@
 
     duplicate_of = {'tf.third': 'tf.fourth'}
     index = {
+        'tf': test_module,
         'tf.fancy': test_function_with_fancy_docstring,
         'tf.reference': HasOneMember,
         'tf.reference.foo': HasOneMember.foo,
@@ -476,20 +532,18 @@
                      'NumPy has nothing as awesome as this function.\n')
 
   def test_generate_index(self):
-    module = sys.modules[__name__]
 
     index = {
-        'TestModule': module,
-        'test_function': test_function,
-        'TestModule.test_function': test_function,
-        'TestModule.TestClass': TestClass,
-        'TestModule.TestClass.a_method': TestClass.a_method,
-        'TestModule.TestClass.a_property': TestClass.a_property,
-        'TestModule.TestClass.ChildClass': TestClass.ChildClass,
+        'tf': test_module,
+        'tf.TestModule': test_module,
+        'tf.test_function': test_function,
+        'tf.TestModule.test_function': test_function,
+        'tf.TestModule.TestClass': TestClass,
+        'tf.TestModule.TestClass.a_method': TestClass.a_method,
+        'tf.TestModule.TestClass.a_property': TestClass.a_property,
+        'tf.TestModule.TestClass.ChildClass': TestClass.ChildClass,
     }
-    duplicate_of = {
-        'TestModule.test_function': 'test_function'
-    }
+    duplicate_of = {'tf.TestModule.test_function': 'tf.test_function'}
 
     visitor = DummyVisitor(index=index, duplicate_of=duplicate_of)
 
@@ -508,7 +562,7 @@
     self.assertIn('TestModule.test_function', docs)
     # Leading backtick to make sure it's included top-level.
     # This depends on formatting, but should be stable.
-    self.assertIn('<code>test_function', docs)
+    self.assertIn('<code>tf.test_function', docs)
 
   def test_argspec_for_functools_partial(self):
     # pylint: disable=unused-argument
@@ -620,22 +674,18 @@
 
     duplicate_of = {'AClass': ['AClass2']}
     doc_index = {'doc': you_cant_serialize_this}
-    is_class = {
+    is_fragment = {
         'tf': False,
-        'tf.AClass': True,
-        'tf.AClass2': True,
-        'tf.function': False
-    }
-    is_module = {
-        'tf': True,
+        'tf.VERSION': True,
         'tf.AClass': False,
+        'tf.AClass.method': True,
         'tf.AClass2': False,
         'tf.function': False
     }
     py_module_names = ['tf', 'tfdbg']
 
-    resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class,
-                                        is_module, py_module_names)
+    resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_fragment,
+                                        py_module_names)
 
     outdir = googletest.GetTempDir()
 
@@ -647,6 +697,23 @@
     # There are no __slots__, so all fields are visible in __dict__.
     self.assertEqual(resolver.__dict__, resolver2.__dict__)
 
+  def testIsFreeFunction(self):
+
+    result = parser.is_free_function(test_function, 'test_module.test_function',
+                                     {'test_module': test_module})
+    self.assertTrue(result)
+
+    result = parser.is_free_function(test_function, 'TestClass.test_function',
+                                     {'TestClass': TestClass})
+    self.assertFalse(result)
+
+    result = parser.is_free_function(TestClass, 'TestClass', {})
+    self.assertFalse(result)
+
+    result = parser.is_free_function(test_module, 'test_module', {})
+    self.assertFalse(result)
+
+
 RELU_DOC = """Computes rectified linear: `max(features, 0)`
 
 Args:
@@ -736,6 +803,5 @@
     sig = parser._generate_signature(example_fun, reverse_index={})
     self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"])
 
-
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index aecf753..1a3e796 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -136,7 +136,7 @@
 
   if page_info.properties:
     parts.append('## Properties\n\n')
-    for prop_info in sorted(page_info.properties):
+    for prop_info in page_info.properties:
       h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
       parts.append(h3.format(short_name=prop_info.short_name))
 
@@ -255,8 +255,9 @@
     #                   at least for basic types.
     parts.append('## Other Members\n\n')
 
+    h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
     for item in page_info.other_members:
-      parts.append('`{short_name}`\n\n'.format(**item._asdict()))
+      parts.append(h3.format(**item._asdict()))
 
   return ''.join(parts)
 
diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
index c8dc2a7..d97496c 100644
--- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
+++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc
@@ -92,7 +92,7 @@
     if (!str_util::EndsWith(name_string, print_suffix)) {
       continue;
     }
-    string name = std::string(
+    string name(
         name_string.substr(0, name_string.size() - print_suffix.size()));
     records->push_back({name, min, max});
   }
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index dd95779..b8d6ba0 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -42,8 +42,8 @@
                       const std::vector<NodeDef*>& inputs, GraphDef* graph_def,
                       bool control_dep = false) {
     NodeDef* node_def = graph_def->add_node();
-    node_def->set_name(std::string(name));
-    node_def->set_op(std::string(op));
+    node_def->set_name(string(name));
+    node_def->set_op(string(op));
     if (!control_dep) {
       std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
         node_def->add_input(input->name());
diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc
index 5cae8f8..7efe450 100644
--- a/tensorflow/tools/graph_transforms/transform_graph.cc
+++ b/tensorflow/tools/graph_transforms/transform_graph.cc
@@ -65,19 +65,19 @@
               .GetResult(&remaining, &transform_name);
       if (!found_transform_name) {
         return errors::InvalidArgument("Looking for transform name, but found ",
-                                       std::string(remaining).c_str());
+                                       string(remaining).c_str());
       }
       if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
         state = TRANSFORM_PARAM_NAME;
       } else {
         // Add a transform with no parameters.
-        params_list->push_back({std::string(transform_name), func_parameters});
+        params_list->push_back({string(transform_name), func_parameters});
         transform_name = "";
         state = TRANSFORM_NAME;
       }
     } else if (state == TRANSFORM_PARAM_NAME) {
       if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
-        params_list->push_back({std::string(transform_name), func_parameters});
+        params_list->push_back({string(transform_name), func_parameters});
         transform_name = "";
         state = TRANSFORM_NAME;
       } else {
@@ -92,13 +92,13 @@
         if (!found_parameter_name) {
           return errors::InvalidArgument(
               "Looking for parameter name, but found ",
-              std::string(remaining).c_str());
+              string(remaining).c_str());
         }
         if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
           state = TRANSFORM_PARAM_VALUE;
         } else {
           return errors::InvalidArgument("Looking for =, but found ",
-                                         std::string(remaining).c_str());
+                                         string(remaining).c_str());
         }
       }
     } else if (state == TRANSFORM_PARAM_VALUE) {
@@ -120,10 +120,9 @@
       }
       if (!found_parameter_value) {
         return errors::InvalidArgument("Looking for parameter name, but found ",
-                                       std::string(remaining).c_str());
+                                       string(remaining).c_str());
       }
-      func_parameters[std::string(parameter_name)].push_back(
-          std::string(parameter_value));
+      func_parameters[string(parameter_name)].emplace_back(parameter_value);
       // Eat up any trailing quotes.
       Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
       Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index cb084e4..c715380 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -93,7 +93,7 @@
   } else {
     *prefix = "";
   }
-  *node_name = std::string(node_name_piece);
+  *node_name = string(node_name_piece);
 }
 
 string NodeNameFromInput(const string& input_name) {
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 91c5cd0..50515b0 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -60,16 +60,6 @@
     ":included_headers",
     "//tensorflow:tensorflow_py",
     "//tensorflow/contrib/autograph:autograph",
-    "//tensorflow/contrib/autograph/converters:converters",
-    "//tensorflow/contrib/autograph/core:core",
-    "//tensorflow/contrib/autograph/core:test_lib",
-    "//tensorflow/contrib/autograph/impl:impl",
-    "//tensorflow/contrib/autograph/lang:lang",
-    "//tensorflow/contrib/autograph/operators:operators",
-    "//tensorflow/contrib/autograph/pyct:pyct",
-    "//tensorflow/contrib/autograph/pyct/testing:testing",
-    "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
-    "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers",
     "//tensorflow/contrib/boosted_trees:boosted_trees_pip",
     "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
     "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
@@ -102,6 +92,16 @@
     "//tensorflow/contrib/timeseries:timeseries_pip",
     "//tensorflow/contrib/tpu",
     "//tensorflow/examples/tutorials/mnist:package",
+    # "//tensorflow/python/autograph/converters:converters",
+    # "//tensorflow/python/autograph/core:core",
+    "//tensorflow/python/autograph/core:test_lib",
+    # "//tensorflow/python/autograph/impl:impl",
+    # "//tensorflow/python/autograph/lang:lang",
+    # "//tensorflow/python/autograph/operators:operators",
+    # "//tensorflow/python/autograph/pyct:pyct",
+    # "//tensorflow/python/autograph/pyct/testing:testing",
+    # "//tensorflow/python/autograph/pyct/static_analysis:static_analysis",
+    "//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
     "//tensorflow/python:cond_v2",
     "//tensorflow/python:distributed_framework_test_lib",
     "//tensorflow/python:meta_graph_testdata",
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 666ea75..c62271c 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -43,8 +43,7 @@
 
 PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
 function is_windows() {
-  # On windows, the shell script is actually running in msys
-  if [[ "${PLATFORM}" =~ (mingw64|msys)_nt* ]]; then
+  if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then
     true
   else
     false
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 61419f2..3102239 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -167,17 +167,21 @@
     # directories for -I
     install_dir = re.sub('/google/protobuf_archive/src', '', install_dir)
 
-    # Copy eigen code into tensorflow/include.
+    # Copy external code headers into tensorflow/include.
     # A symlink would do, but the wheel file that gets created ignores
     # symlink within the directory hierarchy.
     # NOTE(keveman): Figure out how to customize bdist_wheel package so
     # we can do the symlink.
-    if 'tensorflow/include/external/eigen_archive/' in install_dir:
-      extra_dir = install_dir.replace(
-          'tensorflow/include/external/eigen_archive', '')
-      if not os.path.exists(extra_dir):
-        self.mkpath(extra_dir)
-      self.copy_file(header, extra_dir)
+    external_header_locations = [
+        'tensorflow/include/external/eigen_archive/',
+        'tensorflow/include/external/com_google_absl/',
+    ]
+    for location in external_header_locations:
+      if location in install_dir:
+        extra_dir = install_dir.replace(location, '')
+        if not os.path.exists(extra_dir):
+          self.mkpath(extra_dir)
+        self.copy_file(header, extra_dir)
 
     if not os.path.exists(install_dir):
       self.mkpath(install_dir)
@@ -227,6 +231,8 @@
            list(find_files('*.h', 'tensorflow/stream_executor')) +
            list(find_files('*.h', 'google/protobuf_archive/src')) +
            list(find_files('*', 'third_party/eigen3')) +
+           list(find_files('*.h',
+                           'tensorflow/include/external/com_google_absl')) +
            list(find_files('*', 'tensorflow/include/external/eigen_archive')))
 
 setup(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index fdbb1bf..25698da 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@
     tf_http_archive(
         name = "com_google_absl",
         urls = [
-            "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
-            "https://github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
+            "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
+            "https://github.com/abseil/abseil-cpp/archive/8ff1374008259719b54a8cb128ef951c02da164c.tar.gz",
         ],
-        sha256 = "cb4e11259742954f88802be6f33c1007c16502d90d68e8898b5e5084264ca8a9",
-        strip_prefix = "abseil-cpp-c075ad321696fa5072e097f0a51e4fe76a6fe13e",
+        sha256 = "006931f9705484041eed65189038f87931a87cff200bb296f94b3d42339c4cd9",
+        strip_prefix = "abseil-cpp-8ff1374008259719b54a8cb128ef951c02da164c",
         build_file = clean_dep("//third_party:com_google_absl.BUILD"),
     )
 
@@ -240,11 +240,11 @@
     tf_http_archive(
         name = "jpeg",
         urls = [
-            "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
-            "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz",
+            "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+            "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
         ],
-        sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde",
-        strip_prefix = "libjpeg-turbo-1.5.3",
+        sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+        strip_prefix = "libjpeg-turbo-2.0.0",
         build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
         system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
     )
@@ -491,11 +491,11 @@
     tf_http_archive(
         name = "llvm",
         urls = [
-            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
-            "https://github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz",
+            "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+            "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
         ],
-        sha256 = "b8f4ffbcaeea345e2245fd7028c7e960d71c2a2007c20bbfc5d79ecc86992a5e",
-        strip_prefix = "llvm-67bd0d9a0f5597f57f272061fd70f24dffb3d223",
+        sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
+        strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
         build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
     )